diff --git a/pysr/julia_extensions.py b/pysr/julia_extensions.py index 72273f3e..950c292e 100644 --- a/pysr/julia_extensions.py +++ b/pysr/julia_extensions.py @@ -3,6 +3,7 @@ from typing import Literal from .julia_import import Pkg, jl +from .julia_registry_helpers import try_with_registry_fallback from .logger_specs import AbstractLoggerSpec, TensorBoardLoggerSpec @@ -47,8 +48,12 @@ def isinstalled(uuid_s: str): def load_package(package_name: str, uuid_s: str) -> None: if not isinstalled(uuid_s): - Pkg.add(name=package_name, uuid=uuid_s) - Pkg.resolve() + + def _add_package(): + Pkg.add(name=package_name, uuid=uuid_s) + Pkg.resolve() + + try_with_registry_fallback(_add_package) # TODO: Protect against loading the same symbol from two packages, # maybe with a @gensym here. diff --git a/pysr/julia_import.py b/pysr/julia_import.py index 4d7b9150..4ea6b88d 100644 --- a/pysr/julia_import.py +++ b/pysr/julia_import.py @@ -4,6 +4,8 @@ from types import ModuleType from typing import cast +from .julia_registry_helpers import try_with_registry_fallback + # Check if JuliaCall is already loaded, and if so, warn the user # about the relevant environment variables. If not loaded, # set up sensible defaults. @@ -42,6 +44,14 @@ # Deprecated; so just pass to juliacall os.environ["PYTHON_JULIACALL_AUTOLOAD_IPYTHON_EXTENSION"] = autoload_extensions + +def _import_juliacall(): + import juliacall # type: ignore + + +try_with_registry_fallback(_import_juliacall) + + from juliacall import AnyValue # type: ignore from juliacall import VectorValue # type: ignore from juliacall import Main as jl # type: ignore diff --git a/pysr/julia_registry_helpers.py b/pysr/julia_registry_helpers.py new file mode 100644 index 00000000..2c2162e7 --- /dev/null +++ b/pysr/julia_registry_helpers.py @@ -0,0 +1,44 @@ +"""Utilities for managing Julia registry preferences during package operations.""" + +import os +import warnings +from collections.abc import Callable +from typing import TypeVar + +T = TypeVar("T") + +PREFERENCE_KEY = "JULIA_PKG_SERVER_REGISTRY_PREFERENCE" + + +def try_with_registry_fallback(f: Callable[..., T], *args, **kwargs) -> T: + """Execute function with modified Julia registry preference. + + First tries with existing registry preference. If that fails with a Julia registry error, + temporarily modifies the registry preference to 'eager'. Restores original preference after + execution. + """ + try: + return f(*args, **kwargs) + except Exception as initial_error: + # Check if this is a Julia registry error by looking at the error message + if "JuliaError" not in str( + type(initial_error) + ) or "Unsatisfiable requirements detected" not in str(initial_error): + raise initial_error + + old_value = os.environ.get(PREFERENCE_KEY, None) + if old_value == "eager": + raise initial_error + + warnings.warn( + "Initial Julia registry operation failed. Attempting to use the `eager` registry flavor of the Julia " + + f"General registry from the Julia Pkg server (via the `{PREFERENCE_KEY}` environment variable)." + ) + os.environ[PREFERENCE_KEY] = "eager" + try: + return f(*args, **kwargs) + finally: + if old_value is not None: + os.environ[PREFERENCE_KEY] = old_value + else: + del os.environ[PREFERENCE_KEY] diff --git a/pysr/test/test_startup.py b/pysr/test/test_startup.py index a92010ce..4b2a450b 100644 --- a/pysr/test/test_startup.py +++ b/pysr/test/test_startup.py @@ -9,8 +9,9 @@ import numpy as np -from pysr import PySRRegressor +from pysr import PySRRegressor, jl from pysr.julia_import import jl_version +from pysr.julia_registry_helpers import PREFERENCE_KEY, try_with_registry_fallback from .params import DEFAULT_NITERATIONS, DEFAULT_POPULATIONS @@ -159,8 +160,73 @@ def test_notebook(self): self.assertEqual(result.returncode, 0) +class TestRegistryHelper(unittest.TestCase): + """Test the custom Julia registry preference handling.""" + + def setUp(self): + self.old_value = os.environ.get(PREFERENCE_KEY, None) + self.recorded_env_vars = [] + self.hits = 0 + + def failing_operation(): + self.recorded_env_vars.append(os.environ[PREFERENCE_KEY]) + self.hits += 1 + # Just add some package I know will not exist and also not be in the dependency chain: + jl.Pkg.add(name="AirspeedVelocity", version="100.0.0") + + self.failing_operation = failing_operation + + def tearDown(self): + if self.old_value is not None: + os.environ[PREFERENCE_KEY] = self.old_value + else: + os.environ.pop(PREFERENCE_KEY, None) + + def test_successful_operation(self): + self.assertEqual(try_with_registry_fallback(lambda s: s, "success"), "success") + + def test_non_julia_errors_reraised(self): + with self.assertRaises(SyntaxError) as context: + try_with_registry_fallback(lambda: exec("invalid syntax !@#$")) + self.assertNotIn("JuliaError", str(context.exception)) + + def test_julia_error_triggers_fallback(self): + os.environ[PREFERENCE_KEY] = "conservative" + + with self.assertWarns(Warning) as warn_context: + with self.assertRaises(Exception) as error_context: + try_with_registry_fallback(self.failing_operation) + + self.assertIn( + "Unsatisfiable requirements detected", str(error_context.exception) + ) + self.assertIn( + "Initial Julia registry operation failed. Attempting to use the `eager` registry flavor of the Julia", + str(warn_context.warning), + ) + + # Verify both modes are tried in order + self.assertEqual(self.recorded_env_vars, ["conservative", "eager"]) + self.assertEqual(self.hits, 2) + + # Verify environment is restored + self.assertEqual(os.environ[PREFERENCE_KEY], "conservative") + + def test_eager_mode_fails_directly(self): + os.environ[PREFERENCE_KEY] = "eager" + + with self.assertRaises(Exception) as context: + try_with_registry_fallback(self.failing_operation) + + self.assertIn("Unsatisfiable requirements detected", str(context.exception)) + self.assertEqual( + self.recorded_env_vars, ["eager"] + ) # Should only try eager mode + self.assertEqual(self.hits, 1) + + def runtests(just_tests=False): - tests = [TestStartup] + tests = [TestStartup, TestRegistryHelper] if just_tests: return tests suite = unittest.TestSuite()