diff --git a/pysr/__init__.py b/pysr/__init__.py index e26174ab..aabbb669 100644 --- a/pysr/__init__.py +++ b/pysr/__init__.py @@ -1,5 +1,12 @@ +import logging import os +pysr_logger = logging.getLogger("pysr") +pysr_logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +pysr_logger.addHandler(handler) + if os.environ.get("PYSR_USE_BEARTYPE", "0") == "1": from beartype.claw import beartype_this_package diff --git a/pysr/feature_selection.py b/pysr/feature_selection.py index 8c8358fd..13fca487 100644 --- a/pysr/feature_selection.py +++ b/pysr/feature_selection.py @@ -1,5 +1,6 @@ """Functions for doing feature selection during preprocessing.""" +import logging from typing import cast import numpy as np @@ -8,6 +9,8 @@ from .utils import ArrayLike +pysr_logger = logging.getLogger(__name__) + def run_feature_selection( X: ndarray, @@ -44,7 +47,7 @@ def _handle_feature_selection( ): if select_k_features is not None: selection = run_feature_selection(X, y, select_k_features) - print(f"Using features {[variable_names[i] for i in selection]}") + pysr_logger.info(f"Using features {[variable_names[i] for i in selection]}") X = X[:, selection] else: selection = None diff --git a/pysr/sr.py b/pysr/sr.py index f048d8d7..05abd2a3 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -1,6 +1,7 @@ """Define the PySRRegressor scikit-learn interface.""" import copy +import logging import os import pickle as pkl import re @@ -67,6 +68,8 @@ ALREADY_RAN = False +pysr_logger = logging.getLogger(__name__) + def _process_constraints( binary_operators: list[str], @@ -1107,7 +1110,7 @@ def from_file( pkl_filename = Path(run_directory) / "checkpoint.pkl" if pkl_filename.exists(): - print(f"Attempting to load model from {pkl_filename}...") + pysr_logger.info(f"Attempting to load model from {pkl_filename}...") assert binary_operators is None assert unary_operators is None assert n_features_in is None @@ -1129,9 +1132,9 @@ def from_file( return model else: - print( - f"Checkpoint file {pkl_filename} does not exist. " - "Attempting to recreate model from scratch..." + pysr_logger.info( + "Checkpoint file %s does not exist. Attempting to recreate model from scratch...", + pkl_filename, ) csv_filename = Path(run_directory) / "hall_of_fame.csv" csv_filename_bak = Path(run_directory) / "hall_of_fame.csv.bak" @@ -1232,12 +1235,16 @@ def __getstate__(self) -> dict[str, Any]: ) state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"] for state_key in state_keys_containing_lambdas: - if state[state_key] is not None and show_pickle_warning: - warnings.warn( - f"`{state_key}` cannot be pickled and will be removed from the " - "serialized instance. When loading the model, please redefine " - f"`{state_key}` at runtime." - ) + warn_msg = ( + f"`{state_key}` cannot be pickled and will be removed from the " + "serialized instance. When loading the model, please redefine " + f"`{state_key}` at runtime." + ) + if state[state_key] is not None: + if show_pickle_warning: + warnings.warn(warn_msg) + else: + pysr_logger.debug(warn_msg) state_keys_to_clear = state_keys_containing_lambdas state_keys_to_clear.append("logger_") pickled_state = { @@ -1280,7 +1287,7 @@ def _checkpoint(self): try: pkl.dump(self, f) except Exception as e: - print(f"Error checkpointing model: {e}") + pysr_logger.debug(f"Error checkpointing model: {e}") self.show_pickle_warnings_ = True def get_pkl_filename(self) -> Path: @@ -1752,7 +1759,7 @@ def _pre_transform_training_data( self.selection_mask_ = selection_mask self.feature_names_in_ = _check_feature_names_in(self, variable_names) self.display_feature_names_in_ = self.feature_names_in_ - print(f"Using features {self.feature_names_in_}") + pysr_logger.info(f"Using features {self.feature_names_in_}") # Denoising transformation if self.denoise: @@ -1824,7 +1831,7 @@ def _run( # Start julia backend processes if not ALREADY_RAN and runtime_params.update_verbosity != 0: - print("Compiling Julia backend...") + pysr_logger.info("Compiling Julia backend...") parallelism, numprocs = _map_parallelism_params( self.parallelism, self.procs, getattr(self, "multithreading", None)