Skip to content

Commit

Permalink
feat: use proper python logging
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Feb 23, 2025
1 parent b03e02b commit 48bc582
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
7 changes: 7 additions & 0 deletions pysr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import logging
import os

# Package-wide logger configuration: all PySR modules log under the "pysr"
# namespace, so configuring this one logger controls output for the package.
# NOTE(review): adding a StreamHandler at import time means INFO-level records
# go to stderr by default even in applications that configure their own
# logging — presumably intentional here; confirm against the logging HOWTO's
# library guidance (which suggests NullHandler) if this ever surprises users.
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)

pysr_logger = logging.getLogger("pysr")
pysr_logger.setLevel(logging.INFO)
pysr_logger.addHandler(handler)

if os.environ.get("PYSR_USE_BEARTYPE", "0") == "1":
from beartype.claw import beartype_this_package

Expand Down
5 changes: 4 additions & 1 deletion pysr/feature_selection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Functions for doing feature selection during preprocessing."""

import logging
from typing import cast

import numpy as np
Expand All @@ -8,6 +9,8 @@

from .utils import ArrayLike

pysr_logger = logging.getLogger(__name__)


def run_feature_selection(
X: ndarray,
Expand Down Expand Up @@ -44,7 +47,7 @@ def _handle_feature_selection(
):
if select_k_features is not None:
selection = run_feature_selection(X, y, select_k_features)
print(f"Using features {[variable_names[i] for i in selection]}")
pysr_logger.info(f"Using features {[variable_names[i] for i in selection]}")
X = X[:, selection]
else:
selection = None
Expand Down
33 changes: 20 additions & 13 deletions pysr/sr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Define the PySRRegressor scikit-learn interface."""

import copy
import logging
import os
import pickle as pkl
import re
Expand Down Expand Up @@ -67,6 +68,8 @@

ALREADY_RAN = False

pysr_logger = logging.getLogger(__name__)


def _process_constraints(
binary_operators: list[str],
Expand Down Expand Up @@ -1107,7 +1110,7 @@ def from_file(

pkl_filename = Path(run_directory) / "checkpoint.pkl"
if pkl_filename.exists():
print(f"Attempting to load model from {pkl_filename}...")
pysr_logger.info(f"Attempting to load model from {pkl_filename}...")
assert binary_operators is None
assert unary_operators is None
assert n_features_in is None
Expand All @@ -1129,9 +1132,9 @@ def from_file(

return model
else:
print(
f"Checkpoint file {pkl_filename} does not exist. "
"Attempting to recreate model from scratch..."
pysr_logger.info(
"Checkpoint file %s does not exist. Attempting to recreate model from scratch...",
pkl_filename,
)
csv_filename = Path(run_directory) / "hall_of_fame.csv"
csv_filename_bak = Path(run_directory) / "hall_of_fame.csv.bak"
Expand Down Expand Up @@ -1232,12 +1235,16 @@ def __getstate__(self) -> dict[str, Any]:
)
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
for state_key in state_keys_containing_lambdas:
if state[state_key] is not None and show_pickle_warning:
warnings.warn(
f"`{state_key}` cannot be pickled and will be removed from the "
"serialized instance. When loading the model, please redefine "
f"`{state_key}` at runtime."
)
warn_msg = (
f"`{state_key}` cannot be pickled and will be removed from the "
"serialized instance. When loading the model, please redefine "
f"`{state_key}` at runtime."
)
if state[state_key] is not None:
if show_pickle_warning:
warnings.warn(warn_msg)
else:
pysr_logger.debug(warn_msg)
state_keys_to_clear = state_keys_containing_lambdas
state_keys_to_clear.append("logger_")
pickled_state = {
Expand Down Expand Up @@ -1280,7 +1287,7 @@ def _checkpoint(self):
try:
pkl.dump(self, f)
except Exception as e:
print(f"Error checkpointing model: {e}")
pysr_logger.debug(f"Error checkpointing model: {e}")
self.show_pickle_warnings_ = True

def get_pkl_filename(self) -> Path:
Expand Down Expand Up @@ -1752,7 +1759,7 @@ def _pre_transform_training_data(
self.selection_mask_ = selection_mask
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
self.display_feature_names_in_ = self.feature_names_in_
print(f"Using features {self.feature_names_in_}")
pysr_logger.info(f"Using features {self.feature_names_in_}")

# Denoising transformation
if self.denoise:
Expand Down Expand Up @@ -1824,7 +1831,7 @@ def _run(

# Start julia backend processes
if not ALREADY_RAN and runtime_params.update_verbosity != 0:
print("Compiling Julia backend...")
pysr_logger.info("Compiling Julia backend...")

parallelism, numprocs = _map_parallelism_params(
self.parallelism, self.procs, getattr(self, "multithreading", None)
Expand Down

0 comments on commit 48bc582

Please sign in to comment.