Skip to content

Commit

Permalink
Fix BSPlotterProjected.get_elt_projected_plots (#3451)
Browse files Browse the repository at this point in the history
* fix get_elt_projected_plots plotting all element-projected band structures in the same subfigure

* add slightly better tests for get_elt_projected_plots
and get_projected_plots_dots
  • Loading branch information
janosh authored Nov 2, 2023
1 parent 28c8ebb commit 6ce6d9b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 20 deletions.
30 changes: 13 additions & 17 deletions pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,11 +979,10 @@ def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_c
if self._bs.is_metal():
e_min = -10
e_max = 10
count = 1

for el in dictio:
for o in dictio[el]:
ax = plt.subplot(fig_rows + fig_cols + count)
for idx, key in enumerate(dictio[el], 1):
ax = plt.subplot(fig_rows + fig_cols + idx)
self._make_ticks(ax)
for b in range(len(data["distances"])):
for i in range(self._nb_bands):
Expand All @@ -1005,14 +1004,14 @@ def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_c
data["distances"][b][j],
data["energy"][str(Spin.down)][b][i][j],
"ro",
markersize=proj[b][str(Spin.down)][i][j][str(el)][o] * 15.0,
markersize=proj[b][str(Spin.down)][i][j][str(el)][key] * 15.0,
)
for j in range(len(data["energy"][str(Spin.up)][b][i])):
ax.plot(
data["distances"][b][j],
data["energy"][str(Spin.up)][b][i][j],
"bo",
markersize=proj[b][str(Spin.up)][i][j][str(el)][o] * 15.0,
markersize=proj[b][str(Spin.up)][i][j][str(el)][key] * 15.0,
)
if ylim is None:
if self._bs.is_metal():
Expand All @@ -1031,19 +1030,18 @@ def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_c
ax.set_ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max)
else:
ax.set_ylim(ylim)
ax.set_title(f"{el} {o}")
count += 1
return ax
ax.set_title(f"{el} {key}")
return plt.gcf().axes

@no_type_check
def get_elt_projected_plots(self, zero_to_efermi: bool = True, ylim=None, vbm_cbm_marker: bool = False) -> plt.Axes:
"""Method returning a plot composed of subplots along different elements.
Returns:
a pyplot object with different subfigures for each projection
The blue and red colors are for spin up and spin down
The bigger the red or blue dot in the band structure the higher
character for the corresponding element and orbital
np.ndarray[plt.Axes]: 2x2 array of plt.Axes with different subfigures for each projection
The blue and red colors are for spin up and spin down
The bigger the red or blue dot in the band structure the higher
character for the corresponding element and orbital
"""
band_linewidth = 1.0
proj = self._get_projections_by_branches({e.symbol: ["s", "p", "d"] for e in self._bs.structure.elements})
Expand All @@ -1053,9 +1051,8 @@ def get_elt_projected_plots(self, zero_to_efermi: bool = True, ylim=None, vbm_cb
e_min, e_max = -4, 4
if self._bs.is_metal():
e_min, e_max = -10, 10
count = 1
for el in self._bs.structure.elements:
plt.subplot(220 + count)
for idx, el in enumerate(self._bs.structure.elements, 1):
ax = plt.subplot(220 + idx)
self._make_ticks(ax)
for b in range(len(data["distances"])):
for i in range(self._nb_bands):
Expand Down Expand Up @@ -1119,9 +1116,8 @@ def get_elt_projected_plots(self, zero_to_efermi: bool = True, ylim=None, vbm_cb
else:
ax.set_ylim(ylim)
ax.set_title(str(el))
count += 1

return ax
return axs

def get_elt_projected_plots_color(self, zero_to_efermi=True, elt_ordered=None):
"""Returns a pyplot plot object with one plot where the band structure
Expand Down
13 changes: 10 additions & 3 deletions tests/electronic_structure/test_bandstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,16 @@ def test_basic(self):
assert dict_here["O"]["2p"] == approx(0.015)

def test_proj_bandstructure_plot(self):
# make sure that it can be plotted!
BSPlotterProjected(self.bs_spin).get_elt_projected_plots()
BSPlotterProjected(self.bs_spin).get_projected_plots_dots({"Si": ["3s"]})
axs = BSPlotterProjected(self.bs_spin).get_elt_projected_plots()
assert isinstance(axs, np.ndarray)
assert axs.shape == (2, 2)
assert axs[0, 0].get_title() == "Si"
assert axs[0, 1].get_title() == "O"
assert axs[1, 0].get_title() == ""
axs = BSPlotterProjected(self.bs_spin).get_projected_plots_dots({"Si": ["3s"]})
assert isinstance(axs, list)
assert len(axs) == 1
assert axs[0].get_title() == "Si 3s"

def test_get_branch(self):
branch = self.bs_p.get_branch(0)[0]
Expand Down

0 comments on commit 6ce6d9b

Please sign in to comment.