Skip to content

Commit

Permalink
Updated Test file
Browse files Browse the repository at this point in the history
  • Loading branch information
rottenstea committed Feb 22, 2024
1 parent 07d73aa commit 25f4f5a
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions test/test_Simulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_neg_dist_apparent_G():
# ----------------------------------------------------------------------------------------------------------------------
@pytest.fixture
def initialized_class_object():

# Generate random data for columns X and Y
np.random.seed(42) # Set seed for reproducibility
isochrone_data = {
Expand Down Expand Up @@ -160,7 +161,7 @@ def test_add_binary_fraction(initialized_class_object):
# Find indices where elements in second_array are different from first_array
non_matching_indices = np.where(obj.abs_mag_incl_plx_binarity != original_abs_mag_incl_plx)[0]
# Calculate the difference between corresponding elements in the two arrays
differences = obj.abs_mag_incl_plx_binarity - original_abs_mag_incl_plx
differences = obj.abs_mag_incl_plx_binarity - original_abs_mag_incl_plx
# Filter differences corresponding to the non-matching indices
non_matching_differences = differences[non_matching_indices]
# Check if the absolute differences are close to 0.753
Expand Down Expand Up @@ -216,6 +217,7 @@ def test_add_field_unallowed_vals(initialized_class_object):


def test_add_field_contamination_sampling(initialized_class_object):

contamination_frac = 0.9
obj = initialized_class_object
obj.set_CMD_type(1)
Expand All @@ -231,6 +233,7 @@ def test_add_field_contamination_sampling(initialized_class_object):


def test_add_field_contamination_conversion(initialized_class_object):

contamination_frac = 0.7
obj = initialized_class_object
obj.set_CMD_type(1)
Expand All @@ -246,6 +249,7 @@ def test_add_field_contamination_conversion(initialized_class_object):


def test_add_field_contamination_merging(initialized_class_object):

contamination_frac = 0.7
obj = initialized_class_object

Expand Down Expand Up @@ -295,7 +299,7 @@ def test_plot_verification_returns_figure_and_axes(initialized_class_object):
# Verify the return types
assert isinstance(fig, plt.Figure)
assert isinstance(axes, np.ndarray)
assert axes.shape == (6,) # Assuming 2x3 subplots
assert axes.shape == (6, ) # Assuming 2x3 subplots


def test_plot_verification_plots_correct_data(initialized_class_object):
Expand Down Expand Up @@ -323,16 +327,14 @@ def test_plot_verification_plots_correct_data(initialized_class_object):
x_plx_uncertainty = obj.cax
y_plx_uncertainty = obj.abs_mag_incl_plx
plx_uncertainty_scatter = ax_plx_uncertainty.collections[0] # Assuming scatter plot is the only collection
assert np.array_equal(plx_uncertainty_scatter.get_offsets(),
np.column_stack((x_plx_uncertainty, y_plx_uncertainty)))
assert np.array_equal(plx_uncertainty_scatter.get_offsets(), np.column_stack((x_plx_uncertainty, y_plx_uncertainty)))

# Test binary subplot
ax_bin_uncertainty = axes[2]
x_bin_uncertainty = obj.cax
y_bin_uncertainty = obj.abs_mag_incl_plx_binarity
bin_uncertainty_scatter = ax_bin_uncertainty.collections[0] # Assuming scatter plot is the only collection
assert np.array_equal(bin_uncertainty_scatter.get_offsets(),
np.column_stack((x_bin_uncertainty, y_bin_uncertainty)))
assert np.array_equal(bin_uncertainty_scatter.get_offsets(), np.column_stack((x_bin_uncertainty, y_bin_uncertainty)))

# Test Av subplot
ax_Av_uncertainty = axes[3]
Expand Down Expand Up @@ -365,4 +367,9 @@ def test_RSS():
e1 = 0.1
e2 = 0.2
expected_result = np.sqrt(e1 ** 2 + e2 ** 2)
assert np.isclose(RSS(e1, e2), expected_result)
assert np.isclose(RSS(e1, e2), expected_result)





0 comments on commit 25f4f5a

Please sign in to comment.