diff --git a/beast/physicsmodel/grid_and_prior_weights.py b/beast/physicsmodel/grid_and_prior_weights.py index e7d7bffd..5105d7cb 100644 --- a/beast/physicsmodel/grid_and_prior_weights.py +++ b/beast/physicsmodel/grid_and_prior_weights.py @@ -174,17 +174,17 @@ def compute_age_mass_metallicity_weights( # compute the mass weights if len(aindxs) > 1: - # cur_masses = _tgrid_single_age["M_ini"] - # deal with repeat masses - happens for MegaBEAST - cur_masses = np.unique(_tgrid_single_age["M_ini"]) - umass_grid_weights = compute_mass_grid_weights(cur_masses) if isinstance(mass_prior_model, dict): mass_prior = PriorMassModel(mass_prior_model) else: mass_prior = mass_prior_model - umass_prior_weights = mass_prior(cur_masses) + + # deal with repeat masses - happens for MegaBEAST + cur_masses = np.unique(_tgrid_single_age["M_ini"]) n_masses = len(_tgrid_single_age["M_ini"]) if len(cur_masses) < n_masses: + umass_grid_weights = compute_mass_grid_weights(cur_masses) + umass_prior_weights = mass_prior(cur_masses) mass_grid_weights = np.zeros(n_masses, dtype=float) mass_prior_weights = np.zeros(n_masses, dtype=float) for k, cmass in enumerate(cur_masses): @@ -192,8 +192,9 @@ def compute_age_mass_metallicity_weights( mass_grid_weights[gvals] = umass_grid_weights[k] mass_prior_weights[gvals] = umass_prior_weights[k] else: - mass_grid_weights = umass_grid_weights - mass_prior_weights = umass_prior_weights + cur_masses = _tgrid_single_age["M_ini"] + mass_grid_weights = compute_mass_grid_weights(cur_masses) + mass_prior_weights = mass_prior(cur_masses) else: # must be a single mass for this age,z combination # set mass weight to zero to remove this point from the grid