diff --git a/.mailmap b/.mailmap index c10dde1a38c..bfc5b0172b9 100644 --- a/.mailmap +++ b/.mailmap @@ -283,3 +283,5 @@ Israel Roldan AirvZxf airv_zxf Michael Zingale + +Rudraksh Nalbalwar diff --git a/tardis/io/util.py b/tardis/io/util.py index e0cdc1079c0..cd269fd688b 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -138,6 +138,57 @@ def yaml_load_file(filename, loader=yaml.Loader): return yaml.load(stream, Loader=loader) +def parse_species_list(sdec_plotter, data, species_list, packets_mode, nelements=None): + """ + Parse user requested species list and create list of species ids to be used. + + Parameters + ---------- + species_list : list of species to plot + List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. + Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions + (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + + Raises + ------ + ValueError + If species list contains invalid entries. + + """ + from tardis.util.base import atomic_number2element_symbol + sdec_plotter.parse_species_list(species_list) + _species_list = sdec_plotter._species_list + _species_mapped = sdec_plotter._species_mapped + _keep_colour = sdec_plotter._keep_colour + + if nelements: + interaction_counts = ( + data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .value_counts() + ) + interaction_counts.index = interaction_counts.index // 100 + element_counts = interaction_counts.groupby( + interaction_counts.index + ).sum() + top_elements = element_counts.nlargest(nelements).index + top_species_list = [ + atomic_number2element_symbol(element) + for element in top_elements + ] + sub_species_list, sub_species_mapped, sub_keep_colour = parse_species_list( + sdec_plotter, data, top_species_list, packets_mode + ) + _species_list = sub_species_list + _species_mapped = sub_species_mapped + _keep_colour = sub_keep_colour + + return _species_list, _species_mapped, _keep_colour + def traverse_configs(base, other, func, *args): """ Recursively traverse a base dict or list along with another one @@ -160,7 +211,7 @@ def traverse_configs(base, other, func, *args): traverse_configs(base[k], other[k], func, *args) elif ( isinstance(base, collections_abc.Iterable) - and not isinstance(base, basestring) + and not isinstance(base, str) and not hasattr(base, "shape") ): for val1, val2 in zip(base, other): diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py index 0b88dd975cc..723775d0893 100644 --- a/tardis/visualization/tools/liv_plot.py +++ b/tardis/visualization/tools/liv_plot.py @@ -12,6 +12,7 @@ ) import tardis.visualization.tools.sdec_plot as sdec from tardis.visualization import plot_util as pu +from tardis.io.util import parse_species_list logger = logging.getLogger(__name__) @@ -99,49 +100,6 @@ def from_hdf(cls, hdf_fpath): velocity, ) - def _parse_species_list(self, species_list, packets_mode, nelements=None): - """ - Parse user requested species list and create list of species ids to be used. - - Parameters - ---------- - species_list : list of species to plot - List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. - Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions - (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) - packets_mode : str, optional - Packet mode, either 'virtual' or 'real'. Default is 'virtual'. - nelements : int, optional - Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. - - Raises - ------ - ValueError - If species list contains invalid entries. - - """ - self.sdec_plotter._parse_species_list(species_list) - self._species_list = self.sdec_plotter._species_list - self._species_mapped = self.sdec_plotter._species_mapped - self._keep_colour = self.sdec_plotter._keep_colour - - if nelements: - interaction_counts = ( - self.data[packets_mode] - .packets_df_line_interaction["last_line_interaction_species"] - .value_counts() - ) - interaction_counts.index = interaction_counts.index // 100 - element_counts = interaction_counts.groupby( - interaction_counts.index - ).sum() - top_elements = element_counts.nlargest(nelements).index - top_species_list = [ - atomic_number2element_symbol(element) - for element in top_elements - ] - self._parse_species_list(top_species_list, packets_mode) - def _make_colorbar_labels(self): """ Generate labels for the colorbar based on species. @@ -296,7 +254,13 @@ def _prepare_plot_data( f"{atomic_number2element_symbol(specie // 100)}" for specie in species_in_model ] - self._parse_species_list(species_list, packets_mode, nelements) + self._species_list, self._species_mapped, self._keep_colour = parse_species_list( + sdec_plotter=self.sdec_plotter, + data=self.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, + ) species_in_model = np.unique( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 2ca12fbb06f..e772de5fcf1 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -24,7 +24,7 @@ species_string_to_tuple, ) from tardis.visualization import plot_util as pu - +from tardis.io.util import parse_species_list logger = logging.getLogger(__name__) @@ -508,95 +508,7 @@ def from_hdf(cls, hdf_fpath, packets_mode=None): } ) - def _parse_species_list(self, species_list): - """ - Parse user requested species list and create list of species ids to be used. - - Parameters - ---------- - species_list : list of species to plot - List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. - Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions - (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) - - """ - if species_list is not None: - # check if there are any digits in the species list. If there are, then exit. - # species_list should only contain species in the Roman numeral - # format, e.g. Si II, and each ion must contain a space - if any(char.isdigit() for char in " ".join(species_list)) is True: - raise ValueError( - "All species must be in Roman numeral form, e.g. Si II" - ) - else: - full_species_list = [] - species_mapped = {} - for species in species_list: - # check if a hyphen is present. If it is, then it indicates a - # range of ions. Add each ion in that range to the list as a new entry - if "-" in species: - # split the string on spaces. First thing in the list is then the element - element = species.split(" ")[0] - # Next thing is the ion range - # convert the requested ions into numerals - first_ion_numeral = roman_to_int( - species.split(" ")[-1].split("-")[0] - ) - second_ion_numeral = roman_to_int( - species.split(" ")[-1].split("-")[-1] - ) - # add each ion between the two requested into the species list - for ion_number in np.arange( - first_ion_numeral, second_ion_numeral + 1 - ): - full_species_list.append( - f"{element} {int_to_roman(ion_number)}" - ) - else: - # Otherwise it's either an element or ion so just add to the list - full_species_list.append(species) - - # full_species_list is now a list containing each individual species requested - # e.g. it parses species_list = [Si I - V] into species_list = [Si I, Si II, Si III, Si IV, Si V] - self._full_species_list = full_species_list - requested_species_ids = [] - keep_colour = [] - - # go through each of the requested species. Check whether it is - # an element or ion (ions have spaces). If it is an element, - # add all possible ions to the ions list. Otherwise just add - # the requested ion - for species in full_species_list: - if " " in species: - species_id = ( - species_string_to_tuple(species)[0] * 100 - + species_string_to_tuple(species)[1] - ) - requested_species_ids.append([species_id]) - species_mapped[species_id] = [species_id] - else: - atomic_number = element_symbol2atomic_number(species) - species_ids = [ - atomic_number * 100 + ion_number - for ion_number in np.arange(atomic_number) - ] - requested_species_ids.append(species_ids) - species_mapped[atomic_number * 100] = species_ids - # add the atomic number to a list so you know that this element should - # have all species in the same colour, i.e. it was requested like - # species_list = [Si] - keep_colour.append(atomic_number) - requested_species_ids = [ - species_id - for temp_list in requested_species_ids - for species_id in temp_list - ] - self._species_mapped = species_mapped - self._species_list = requested_species_ids - self._keep_colour = keep_colour - else: - self._species_list = None def _calculate_plotting_data( self, packets_mode, packet_wvl_range, distance, nelements @@ -1210,8 +1122,13 @@ def generate_plot_mpl( ) # Parse the requested species list - self._parse_species_list(species_list=species_list) - + self._species_list, self._species_mapped, self._keep_colour = parse_species_list( + sdec_plotter=self, + data=self.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, + ) # Calculate data attributes required for plotting # and save them in instance itself self._calculate_plotting_data( @@ -1610,8 +1527,13 @@ def generate_plot_ply( ) # Parse the requested species list - self._parse_species_list(species_list=species_list) - + self._species_list, self._species_mapped, self._keep_colour = parse_species_list( + sdec_plotter=self, + data=self.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, + ) # Calculate data attributes required for plotting # and save them in instance itself self._calculate_plotting_data( diff --git a/tardis/visualization/tools/tests/test_liv_plot.py b/tardis/visualization/tools/tests/test_liv_plot.py index 5a4f56897bc..b13de734167 100644 --- a/tardis/visualization/tools/tests/test_liv_plot.py +++ b/tardis/visualization/tools/tests/test_liv_plot.py @@ -12,7 +12,7 @@ from tardis.io.util import HDFWriterMixin from tardis.visualization.tools.liv_plot import LIVPlotter from tardis.tests.fixtures.regression_data import RegressionData - +from tardis.io.util import parse_species_list class PlotDataHDF(HDFWriterMixin): """ @@ -116,7 +116,7 @@ def test_parse_species_list( attribute, ): """ - Test for the _parse_species_list method in LIVPlotter. + Test for the parse_species_list method in LIVPlotter. Parameters: ----------- @@ -125,11 +125,22 @@ def test_parse_species_list( attribute: The attribute to test after parsing the species list. """ regression_data = RegressionData(request) - plotter._parse_species_list( - packets_mode=self.packets_mode[0], - species_list=self.species_list[0], - nelements=self.nelements[0], + + packets_mode=self.packets_mode[0] + species_list=self.species_list[0] + nelements=self.nelements[0] + + species_list_parsed, species_mapped, keep_colour = parse_species_list( + sdec_plotter=plotter.sdec_plotter, + data=plotter.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, ) + plotter._species_list = species_list_parsed + plotter._species_mapped = species_mapped + plotter._keep_colour = keep_colour + if attribute == "_species_mapped": plot_object = getattr(plotter, attribute) plot_object = [ diff --git a/tardis/visualization/tools/tests/test_sdec_plot.py b/tardis/visualization/tools/tests/test_sdec_plot.py index 1136ba7e148..8aabf3ebcba 100644 --- a/tardis/visualization/tools/tests/test_sdec_plot.py +++ b/tardis/visualization/tools/tests/test_sdec_plot.py @@ -15,6 +15,7 @@ from tardis.io.util import HDFWriterMixin from tardis.tests.fixtures.regression_data import RegressionData from tardis.visualization.tools.sdec_plot import SDECPlotter +from tardis.io.util import parse_species_list class PlotDataHDF(HDFWriterMixin): @@ -163,7 +164,7 @@ def observed_spectrum(self): ) def test_parse_species_list(self, request, plotter, attribute): """ - Test _parse_species_list method. + Test parse_species_list method. Parameters ---------- @@ -172,7 +173,13 @@ def test_parse_species_list(self, request, plotter, attribute): species : list """ # THIS NEEDS TO BE RUN FIRST. NOT INDEPENDENT TESTS - plotter._parse_species_list(self.species_list[0]) + full_species_list, species_list, keep_colour = parse_species_list(self.species_list[0]) + + # Set the attributes manually on the plotter for testing + plotter._full_species_list = full_species_list + plotter._species_list = species_list + plotter._keep_colour = keep_colour + regression_data = RegressionData(request) data = regression_data.sync_ndarray(getattr(plotter, attribute)) if attribute == "_full_species_list":