Skip to content

Commit

Permalink
refactor: clean up tests of mpi
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Aug 13, 2024
1 parent 5365c6b commit 5e9b679
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,38 +85,43 @@ def test_linear_relation_weighted_bumper(self):
jl.seval("((::Val{x}) where x) -> x")(model.julia_options_.bumper), True
)

def _multiprocessing_turbo_custom_objective(self, cluster_manager):
rstate = np.random.RandomState(0)
y = self.X[:, 0]
y += rstate.randn(*y.shape) * 1e-4
model = PySRRegressor(
**self.default_test_kwargs,
# Turbo needs to work with unsafe operators:
unary_operators=["sqrt"],
procs=2,
multithreading=False,
turbo=True,
early_stop_condition="stop_if(loss, complexity) = loss < 1e-10 && complexity == 1",
loss_function="""
function my_objective(tree::Node{T}, dataset::Dataset{T}, options::Options) where T
prediction, flag = eval_tree_array(tree, dataset.X, options)
!flag && return T(Inf)
abs3(x) = abs(x) ^ 3
return sum(abs3, prediction .- dataset.y) / length(prediction)
end
""",
)
model.fit(self.X, y)
print(model.equations_)
best_loss = model.equations_.iloc[-1]["loss"]
self.assertLessEqual(best_loss, 1e-10)
self.assertGreaterEqual(best_loss, 0.0)

# Test options stored:
self.assertEqual(
jl.seval("((::Val{x}) where x) -> x")(model.julia_options_.turbo), True
)

def test_multiprocessing_turbo_custom_objective(self):
for cluster_manager in (None, "mpi"):
rstate = np.random.RandomState(0)
y = self.X[:, 0]
y += rstate.randn(*y.shape) * 1e-4
model = PySRRegressor(
**self.default_test_kwargs,
# Turbo needs to work with unsafe operators:
unary_operators=["sqrt"],
procs=2,
multithreading=False,
turbo=True,
early_stop_condition="stop_if(loss, complexity) = loss < 1e-10 && complexity == 1",
loss_function="""
function my_objective(tree::Node{T}, dataset::Dataset{T}, options::Options) where T
prediction, flag = eval_tree_array(tree, dataset.X, options)
!flag && return T(Inf)
abs3(x) = abs(x) ^ 3
return sum(abs3, prediction .- dataset.y) / length(prediction)
end
""",
)
model.fit(self.X, y)
print(model.equations_)
best_loss = model.equations_.iloc[-1]["loss"]
self.assertLessEqual(best_loss, 1e-10)
self.assertGreaterEqual(best_loss, 0.0)

# Test options stored:
self.assertEqual(
jl.seval("((::Val{x}) where x) -> x")(model.julia_options_.turbo), True
)
self._multiprocessing_turbo_custom_objective(None)

def test_multiprocessing_turbo_custom_objective_mpi(self):
self._multiprocessing_turbo_custom_objective("mpi")

def test_multiline_seval(self):
# The user should be able to run multiple things in a single seval call:
Expand Down

0 comments on commit 5e9b679

Please sign in to comment.