Skip to content

Commit

Permalink
add more pytest on the completely_random_rotation; #59
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Oct 24, 2024
1 parent 8826ea1 commit 93ed961
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/test_random_state_reproducibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,34 @@ def test_random_state_reproducibility():
assert np.sum(ensemble_df1['stixel_checklist_count'].values -
ensemble_df2['stixel_checklist_count'].values) != 0


def test_random_state_reproducibility_completely_random_rotation_angle():
model = make_AdaSTEMRegressor()
model = model.set_params(random_state=42, completely_random_rotation=True)
model.split(X, verbosity=1)
ensemble_df1 = model.ensemble_df.copy()

model = model.set_params(random_state=990324, completely_random_rotation=True)
model.split(X, verbosity=1)
ensemble_df2 = model.ensemble_df.copy()

model = model.set_params(random_state=42, completely_random_rotation=True)
model.split(X, verbosity=1)
ensemble_df3 = model.ensemble_df.copy()

# 1 and 3 (with the same random state)
assert ensemble_df1.shape == ensemble_df3.shape
assert np.allclose(ensemble_df1['calibration_point_x_jitter'].values,
ensemble_df3['calibration_point_x_jitter'].values)
assert np.allclose(ensemble_df1['stixel_checklist_count'].values,
ensemble_df3['stixel_checklist_count'].values)

# 1 and 2 (with the different random state)
if ensemble_df1.shape != ensemble_df2.shape:
pass
else:
assert np.sum(ensemble_df1['calibration_point_x_jitter'].values -
ensemble_df2['calibration_point_x_jitter'].values) != 0
assert np.sum(ensemble_df1['stixel_checklist_count'].values -
ensemble_df2['stixel_checklist_count'].values) != 0

0 comments on commit 93ed961

Please sign in to comment.