diff --git a/docs/examples/band-struc.md b/docs/examples/band-struc.md index 43791d5710..f1a31fe814 100644 --- a/docs/examples/band-struc.md +++ b/docs/examples/band-struc.md @@ -27,6 +27,7 @@ pw_diag_thr 1.0e-7 #Parameters (File) init_chg file out_band 1 +out_proj_band 1 #Parameters (Smearing) smearing_method gaussian @@ -58,4 +59,35 @@ points. Run the program, and you will see a file named BANDS_1.dat in the output directory. Plot it to get energy band structure. -[back to top](#band-structure) \ No newline at end of file +If "out_proj_band" set 1, it will also produce the projected band structure in a file called PBAND_1 in xml format. + +The PBAND_1 file starts with number of atomic orbitals in the system, the text contents of element is the same as data in the BANDS_1.dat file, such as: +``` + +1 +153 + +... + +``` + +The rest of the files arranged in sections, each section with a header such as below: + +``` + + +... + + +``` + +The shape of text contents of element is (Number of k-points, Number of bands) + +[back to top](#band-structure) diff --git a/docs/input-main.md b/docs/input-main.md index 055d3a774c..6fb58c74c1 100644 --- a/docs/input-main.md +++ b/docs/input-main.md @@ -32,7 +32,7 @@ - [Variables related to output information](#variables-related-to-output-information) - [out_force](#out_force) | [out_mul](#out_mul) | [out_freq_elec](#out_freq_elec) | [out_freq_ion](#out_freq_ion) | [out_chg](#out_chg) | [out_pot](#out_pot) | [out_dm](#out-dm) | [out_wfc_pw](#out_wfc_pw) | [out_wfc_r](#out_wfc_r) | [out_wfc_lcao](#out_wfc_lcao) | [out_dos](#out-dos) | [out_band](#out-band) | [out_stru](#out-stru) | [out_level](#out_level) | [out_alllog](#out-alllog) | [out_mat_hs](#out_mat_hs) | [out_mat_r](#out_mat_r) | [out_mat_hs2](#out_mat_hs2) | [out_element_info](#out-element-info) | [restart_save](#restart_save) | [restart_load](#restart_load) + [out_force](#out_force) | [out_mul](#out_mul) | [out_freq_elec](#out_freq_elec) | [out_freq_ion](#out_freq_ion) | [out_chg](#out_chg) | [out_pot](#out_pot) | [out_dm](#out-dm) | [out_wfc_pw](#out_wfc_pw) | [out_wfc_r](#out_wfc_r) | [out_wfc_lcao](#out_wfc_lcao) | [out_dos](#out-dos) | [out_band](#out-band) | [out_proj_band](#out-proj-band) | [out_stru](#out-stru) | [out_level](#out_level) | [out_alllog](#out-alllog) | [out_mat_hs](#out_mat_hs) | [out_mat_r](#out_mat_r) | [out_mat_hs2](#out_mat_hs2) | [out_element_info](#out-element-info) | [restart_save](#restart_save) | [restart_load](#restart_load) - [Density of states](#density-of-states) @@ -735,6 +735,12 @@ This part of variables are used to control the output of properties. - **Description**: Controls whether to output the band structure. For mroe information, refer to the [worked example](examples/band-struc.md) - **Default**: 0 +#### out_proj_band + +- **Type**: Integer +- **Description**: Controls whether to output the projected band structure. For mroe information, refer to the [worked example](examples/band-struc.md) +- **Default**: 0 + #### out_stru - **Type**: Boolean diff --git a/source/input.cpp b/source/input.cpp index 5033ba7609..e7eee94d8c 100644 --- a/source/input.cpp +++ b/source/input.cpp @@ -268,6 +268,7 @@ void Input::Default(void) out_wfc_r = 0; out_dos = 0; out_band = 0; + out_proj_band = 0; out_mat_hs = 0; out_mat_hs2 = 0; // LiuXh add 2019-07-15 out_mat_r = 0; // jingan add 2019-8-14 @@ -995,6 +996,10 @@ bool Input::Read(const std::string &fn) { read_value(ifs, out_band); } + else if (strcmp("out_proj_band", word) == 0) + { + read_value(ifs, out_proj_band); + } else if (strcmp("out_mat_hs", word) == 0) { @@ -1941,6 +1946,7 @@ void Input::Bcast() Parallel_Common::bcast_int(out_wfc_r); Parallel_Common::bcast_int(out_dos); Parallel_Common::bcast_int(out_band); + Parallel_Common::bcast_int(out_proj_band); Parallel_Common::bcast_int(out_mat_hs); Parallel_Common::bcast_int(out_mat_hs2); // LiuXh add 2019-07-15 Parallel_Common::bcast_int(out_mat_r); // jingan add 2019-8-14 @@ -2227,6 +2233,7 @@ void Input::Check(void) out_stru = 0; out_dos = 0; out_band = 0; + out_proj_band = 0; cal_force = 0; init_wfc = "file"; init_chg = "atomic"; // useless, @@ -2248,6 +2255,7 @@ void Input::Check(void) out_stru = 0; out_dos = 0; out_band = 0; + out_proj_band = 0; cal_force = 0; init_wfc = "file"; init_chg = "atomic"; diff --git a/source/input.h b/source/input.h index 3336c772e8..54923b8b7f 100644 --- a/source/input.h +++ b/source/input.h @@ -206,6 +206,7 @@ class Input int out_wfc_r; // 0: no; 1: yes int out_dos; // dos calculation. mohan add 20090909 int out_band; // band calculation pengfei 2014-10-13 + int out_proj_band; // projected band structure calculation jiyy add 2022-05-11 int out_mat_hs; // output H matrix and S matrix in local basis. int out_mat_hs2; // LiuXh add 2019-07-16, output H(R) matrix and S(R) matrix in local basis. int out_mat_r; // jingan add 2019-8-14, output r(R) matrix. diff --git a/source/input_conv.cpp b/source/input_conv.cpp index 4a7c5e42f5..d86af8fb31 100644 --- a/source/input_conv.cpp +++ b/source/input_conv.cpp @@ -434,6 +434,7 @@ void Input_Conv::Convert(void) GlobalC::wf.out_wfc_r = INPUT.out_wfc_r; GlobalC::en.out_dos = INPUT.out_dos; GlobalC::en.out_band = INPUT.out_band; + GlobalC::en.out_proj_band = INPUT.out_proj_band; #ifdef __LCAO Local_Orbital_Charge::out_dm = INPUT.out_dm; Pdiag_Double::out_mat_hs = INPUT.out_mat_hs; diff --git a/source/src_io/energy_dos.cpp b/source/src_io/energy_dos.cpp index a0ca1c084c..85884cc78a 100644 --- a/source/src_io/energy_dos.cpp +++ b/source/src_io/energy_dos.cpp @@ -29,7 +29,7 @@ void energy::perform_dos(Local_Orbital_wfc &lowf, LCAO_Hamilt &uhm) const Parallel_Orbitals* pv = uhm.LM->ParaV; - if(out_dos !=0 || out_band !=0) + if(out_dos !=0 || out_band !=0 || out_proj_band !=0) { GlobalV::ofs_running << "\n\n\n\n"; GlobalV::ofs_running << " >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>" << std::endl; @@ -461,35 +461,8 @@ void energy::perform_dos(Local_Orbital_wfc &lowf, LCAO_Hamilt &uhm) out.close(); } - std::string Name_Angular[5][11]; /* decomposed Mulliken charge */ - Name_Angular[0][0] = "s "; - Name_Angular[1][0] = "px "; - Name_Angular[1][1] = "py "; - Name_Angular[1][2] = "pz "; - Name_Angular[2][0] = "d3z^2-r^2 "; - Name_Angular[2][1] = "dxy "; - Name_Angular[2][2] = "dxz "; - Name_Angular[2][3] = "dx^2-y^2 "; - Name_Angular[2][4] = "dyz "; - Name_Angular[3][0] = "f5z^2-3r^2 "; - Name_Angular[3][1] = "f5xz^2-xr^2"; - Name_Angular[3][2] = "f5yz^2-yr^2"; - Name_Angular[3][3] = "fzx^2-zy^2 "; - Name_Angular[3][4] = "fxyz "; - Name_Angular[3][5] = "fx^3-3*xy^2"; - Name_Angular[3][6] = "f3yx^2-y^3 "; - Name_Angular[4][0] = "g1 "; - Name_Angular[4][1] = "g2 "; - Name_Angular[4][2] = "g3 "; - Name_Angular[4][3] = "g4 "; - Name_Angular[4][4] = "g5 "; - Name_Angular[4][5] = "g6 "; - Name_Angular[4][6] = "g7 "; - Name_Angular[4][7] = "g8 "; - Name_Angular[4][8] = "g9 "; - {std::stringstream as; as << GlobalV::global_out_dir << "PDOS"; std::ofstream out(as.str().c_str()); @@ -563,34 +536,7 @@ void energy::perform_dos(Local_Orbital_wfc &lowf, LCAO_Hamilt &uhm) out << "<"<<"/"<<"pdos"<<">" <nw; ++j) - { - const int L1 = atom1->iw2l[j]; - const int N1 = atom1->iw2n[j]; - const int m1 = atom1->iw2m[j]; - out <print_orbital_file(); } delete[] pdos; @@ -675,9 +621,303 @@ void energy::perform_dos(Local_Orbital_wfc &lowf, LCAO_Hamilt &uhm) ss2 << GlobalV::global_out_dir << "BANDS_" << is+1 << ".dat"; GlobalV::ofs_running << "\n Output bands in file: " << ss2.str() << std::endl; Dos::nscf_band(is, ss2.str(), nks, GlobalV::NBANDS, this->ef*0, GlobalC::wf.ekb); + } + }//out_band - } + if(this->out_proj_band) // Projeced band structure added by jiyy-2022-4-20 + { + int nks=0; + if(nspin0==1) + { + nks = GlobalC::kv.nkstot; + } + else if(nspin0==2) + { + nks = GlobalC::kv.nkstot/2; + } + + ModuleBase::ComplexMatrix weightk; + ModuleBase::matrix weight; + int NUM = 0; + if(GlobalV::GAMMA_ONLY_LOCAL) + { + NUM=GlobalV::NLOCAL*GlobalV::NBANDS*nspin0; + weightk.create(nspin0, GlobalV::NBANDS*GlobalV::NLOCAL, true); + weight.create(nspin0, GlobalV::NBANDS*GlobalV::NLOCAL, true); + } + else + { + NUM=GlobalV::NLOCAL*GlobalV::NBANDS*GlobalC::kv.nks; + weightk.create(GlobalC::kv.nks, GlobalV::NBANDS*GlobalV::NLOCAL, true); + weight.create(GlobalC::kv.nks, GlobalV::NBANDS*GlobalV::NLOCAL, true); + } + + for(int is=0; is Mulk; + Mulk.resize(1); + Mulk[0].create(pv->ncol,pv->nrow); + + + ModuleBase::matrix Dwf = lowf.wfc_gamma[is]; + for (int i=0; iSloc.data(), &one_int, &one_int, pv->desc, + Dwf.c, &one_int, &NB, pv->desc, &one_int, + &zero_float, + Mulk[0].c, &one_int, &NB, pv->desc, + &one_int); + #endif + + for (int j=0; jin_this_processor(j,i) ) + { + + const int ir = pv->trace_loc_row[j]; + const int ic = pv->trace_loc_col[i]; + weightk(is, i*GlobalV::NLOCAL+j) = Mulk[0](ic,ir)*lowf.wfc_gamma[is](ic,ir); + } + } + }//ib + }//if + else + { + GlobalV::SEARCH_RADIUS = atom_arrange::set_sr_NL( + GlobalV::ofs_running, + GlobalV::OUT_LEVEL, + GlobalC::ORB.get_rcutmax_Phi(), + GlobalC::ucell.infoNL.get_rcutmax_Beta(), + GlobalV::GAMMA_ONLY_LOCAL); + + atom_arrange::search( + GlobalV::SEARCH_PBC, + GlobalV::ofs_running, + GlobalC::GridD, + GlobalC::ucell, + GlobalV::SEARCH_RADIUS, + GlobalV::test_atom_input);//qifeng-2019-01-21 + + uhm.LM->allocate_HS_R(pv->nnr); + uhm.LM->zeros_HSR('S'); + uhm.genH.calculate_S_no(uhm.LM->SlocR.data()); + uhm.genH.build_ST_new('S', false, GlobalC::ucell, uhm.LM->SlocR.data()); + std::vector Mulk; + Mulk.resize(1); + Mulk[0].create(pv->ncol,pv->nrow); + + + for(int ik=0;ikallocate_HS_k(pv->nloc); + uhm.LM->zeros_HSk('S'); + uhm.LM->folding_fixedH(ik); + + + ModuleBase::ComplexMatrix Dwfc = conj(lowf.wfc_k[ik]); + + for (int i=0; iSloc2.data(), &one_int, &one_int, pv->desc, + Dwfc.c, &one_int, &NB, pv->desc, &one_int, + &zero_float[0], + Mulk[0].c, &one_int, &NB, pv->desc, + &one_int); + #endif + + + for (int j=0; jin_this_processor(j,i) ) + { + + const int ir = pv->trace_loc_row[j]; + const int ic = pv->trace_loc_col[i]; + + weightk(ik, i*GlobalV::NLOCAL+j) = Mulk[0](ic,ir)*lowf.wfc_k[ik](ic,ir); + } + } + + }//ib + + }//if + }//ik +#ifdef __MPI + atom_arrange::delete_vector( + GlobalV::ofs_running, + GlobalV::SEARCH_PBC, + GlobalC::GridD, + GlobalC::ucell, + GlobalV::SEARCH_RADIUS, + GlobalV::test_atom_input); +#endif + }//else + #ifdef __MPI + MPI_Reduce(weightk.real().c, weight.c , NUM , MPI_DOUBLE , MPI_SUM, 0, MPI_COMM_WORLD); + #endif + + if(GlobalV::MY_RANK == 0) + { + std::stringstream ps2; + ps2 << GlobalV::global_out_dir << "PBANDS_" << is+1; + GlobalV::ofs_running << "\n Output projected bands in file: " << ps2.str() << std::endl; + std::ofstream out(ps2.str().c_str()); + + out << "<"<<"pband"<<">" <" << GlobalV::NSPIN<< "<"<<"/"<<"nspin"<<">"<< std::endl; + if (GlobalV::NSPIN==4) + out << "<"<<"norbitals"<<">" <"<< std::endl; + else + out << "<"<<"norbitals"<<">" <"<< std::endl; + out << "<"<<"band_structure nkpoints="<<"\""<" <"<nw; ++j) + { + const int L1 = atom1->iw2l[j]; + const int N1 = atom1->iw2n[j]; + const int m1 = atom1->iw2m[j]; + const int w = GlobalC::ucell.itiaiw2iwt(t, a, j); + + //out << "<"<<"/"<<"energy"<<"_"<<"values"<<">" <" <" <" <" <" <print_orbital_file(); + }//out_proj_band return; } +void energy::print_orbital_file(void) +{ + std::stringstream os; + os<nw; ++j) + { + const int L1 = atom1->iw2l[j]; + const int N1 = atom1->iw2n[j]; + const int m1 = atom1->iw2m[j]; + out < None: - self.fig = fig - self.ax = ax - self._lw = kwargs.pop('lw', 2) - self._bwidth = kwargs.pop('bwdith', 3) - self._label = kwargs.pop('label', None) - self._color = kwargs.pop('color', None) - self._linestyle = kwargs.pop('linestyle', 'solid') - self.plot_params = kwargs + def __init__(self, bandfile: Union[PathLike, Sequence[PathLike]] = None, kptfile: PathLike = '') -> None: + self.bandfile = bandfile + if isinstance(bandfile, list) or isinstance(bandfile, tuple): + self.energy = [] + for file in self.bandfile: + self.k_index, e = self.read(file) + self.energy.append(e) + else: + self.k_index, self.energy = self.read(self.bandfile) + self.energy = np.asarray(self.energy) + self.kptfile = kptfile + self.kpt = None + if self.kptfile: + self.kpt = read_kpt(kptfile) + self.k_index = list(map(int, self.k_index)) + + @classmethod + def read(cls, filename: PathLike): + """Read band data file and return k-points and energy + + :params filename: string of band data file + """ + + data = np.loadtxt(filename, dtype=float) + X, y = np.split(data, (1, ), axis=1) + x = X.flatten() + + return x, y + + @classmethod + def direct_bandgap(cls, vb: namedtuple, cb: namedtuple, klength: int): + """Calculate direct band gap""" + + gap_list = [] + i_index = [] + for i in range(klength): + gap_list.append(np.min(cb.band[:, i])-np.max(vb.band[:, i])) + i_index.append(i) + dgap = np.min(gap_list) + + return dgap, i_index[np.argmin(gap_list)] + + @classmethod + def bandgap(cls, vb: namedtuple, cb: namedtuple): + """Calculate band gap""" + + gap = cb.value-vb.value + + return gap + + @classmethod + def band_type(cls, vb: namedtuple, cb: namedtuple): + vbm_x, cbm_x = vb.k_index, cb.k_index + longone, shortone = (vbm_x, cbm_x) if len( + vbm_x) >= len(cbm_x) else (cbm_x, vbm_x) + for i in shortone: + if i in longone: + btype = "Direct" + else: + btype = "Indirect" + + return btype + + @classmethod + def info(cls, kpath: Sequence, vb: namedtuple, cb: namedtuple): + """Output the information of band structure + + :params kpath: k-points path + :params energy: band energy after subtracting the Fermi level + """ + + gap = cls.bandgap(vb, cb) + dgap, d_i = cls.direct_bandgap(vb, cb, len(kpath)) + btype = cls.band_type(vb, cb) + print( + "--------------------------Band Structure--------------------------", flush=True) + print( + f"{'Band character:'.ljust(30)}{btype}", flush=True) + if btype == "Indirect": + print(f"{'Direct Band gap(eV):'.ljust(30)}{dgap: .4f}", flush=True) + print(f"{'Indirect Band gap(eV):'.ljust(30)}{gap: .4f}", flush=True) + elif btype == "Direct": + print(f"{'Band gap(eV):'.ljust(30)}{gap: .4f}", flush=True) + print(f"{'Band index:'.ljust(30)}{'HOMO'.ljust(10)}{'LUMO'}", flush=True) + print( + f"{''.ljust(30)}{str(vb.band_index[-1]).ljust(10)}{str(cb.band_index[0])}", flush=True) + print(f"{'Eigenvalue of VBM(eV):'.ljust(30)}{vb.value: .4f}", flush=True) + print(f"{'Eigenvalue of CBM(eV):'.ljust(30)}{cb.value: .4f}", flush=True) + vbm_k = np.unique(kpath[vb.k_index], axis=0) + cbm_k = np.unique(kpath[cb.k_index], axis=0) + print( + f"{'Location of VBM'.ljust(30)}{' '.join(list_elem2str(vbm_k[0]))}", flush=True) + for i, j in enumerate(vbm_k): + if i != 0: + print(f"{''.ljust(30)}{' '.join(list_elem2str(j))}", flush=True) + print( + f"{'Location of CBM'.ljust(30)}{' '.join(list_elem2str(cbm_k[0]))}", flush=True) + for i, j in enumerate(cbm_k): + if i != 0: + print(f"{''.ljust(30)}{' '.join(list_elem2str(j))}", flush=True) @classmethod def set_vcband(cls, energy: Sequence) -> Tuple[namedtuple, namedtuple]: @@ -56,20 +151,109 @@ def set_vcband(cls, energy: Sequence) -> Tuple[namedtuple, namedtuple]: return vb, cb + def _shift_energy(self, energy, efermi: float = 0, shift: bool = False): + energy = energy_minus_efermi(energy, efermi) + if shift: + vb, cb = self.set_vcband(energy) + refine_E = np.vstack((vb.band, cb.band)).T + self.info(self.kpt.full_kpath, vb, cb) + else: + refine_E = energy + + return refine_E + @classmethod - def read(cls, filename: PathLike) -> Tuple[np.ndarray, np.ndarray]: - """Read band data file and return k-points and energy + def plot_data(cls, + fig: Figure, + ax: axes.Axes, + x: Sequence, + y: Sequence, + index: Sequence, + efermi: float = 0, + energy_range: Sequence[float] = [], + **kwargs): + """Plot band structure - :params filename: string of band data file + :params x, y: x-axis and y-axis coordinates + :params index: special k-points label and its index in data file + :params efermi: Fermi level in unit eV + :params energy_range: range of energy to plot, its length equals to two """ - data = np.loadtxt(filename) - X, y = np.split(data, (1, ), axis=1) - x = X.flatten() + bandplot = BandPlot(fig, ax, **kwargs) + if not bandplot._color: + bandplot._color = 'black' - return x, y + kpoints, energy = x, y + energy = energy_minus_efermi(energy, efermi) + + bandplot.ax.plot(kpoints, energy, lw=bandplot._lw, color=bandplot._color, + label=bandplot._label, linestyle=bandplot._linestyle) + bandplot._set_figure(index, energy_range) + + def plot(self, + fig: Figure, + ax: axes.Axes, + efermi: Union[float, Sequence[float]] = [], + energy_range: Sequence[float] = [], + shift: bool = True, + **kwargs): + """Plot more than two band structures using data file - def _set_figure(self, index: dict, range: Sequence): + :params efermi: Fermi levels in unit eV, its length equals to `filename` + :params energy_range: range of energy to plot, its length equals to two + :params shift: if sets True, it will calculate band gap. This parameter usually is suitable for semiconductor and insulator. Default: False + """ + + bandplot = BandPlot(fig, ax, **kwargs) + nums = len(self.bandfile) + + if isinstance(self.bandfile, list): + if not efermi: + efermi = [0.0 for i in range(nums)] + if not kwargs.pop('label', None): + bandplot._label = ['' for i in range(nums)] + if not kwargs.pop('color', None): + bandplot._color = ['black' for i in range(nums)] + if not kwargs.pop('linestyle', None): + bandplot._linestyle = ['solid' for i in range(nums)] + + for i, band in enumerate(self.energy): + band = self._shift_energy(band, efermi[i], shift) + bandplot.ax.plot(self.k_index, band, + lw=bandplot._lw, color=bandplot._color[i], label=bandplot._label[i], linestyle=bandplot._linestyle[i]) + + else: + if not efermi: + efermi = 0.0 + + band = self._shift_energy(self.energy, efermi, shift) + bandplot.ax.plot(self.k_index, band, + lw=bandplot._lw, color=bandplot._color, label=bandplot._label, linestyle=bandplot._linestyle) + + if self.kpt: + index = self.kpt.label_special_k + else: + index = self.k_index + bandplot._set_figure(index, energy_range) + + return bandplot + + +class BandPlot: + """Plot band structure""" + + def __init__(self, fig: Figure, ax: axes.Axes, **kwargs) -> None: + self.fig = fig + self.ax = ax + self._lw = kwargs.pop('lw', 2) + self._bwidth = kwargs.pop('bwdith', 3) + self._label = kwargs.pop('label', None) + self._color = kwargs.pop('color', 'black') + self._linestyle = kwargs.pop('linestyle', 'solid') + self.plot_params = kwargs + + def _set_figure(self, index, range: Sequence): """set figure and axes for plotting :params index: dict of label of points of x-axis and its index in data file. Range of x-axis based on index.value() @@ -82,7 +266,7 @@ def _set_figure(self, index: dict, range: Sequence): if isinstance(t, tuple): keys.append(t[0]) values.append(t[1]) - elif isinstance(t, (int, float)): + elif isinstance(t, int): keys.append('') values.append(t) @@ -128,10 +312,11 @@ def _set_figure(self, index: dict, range: Sequence): self.ax.spines['bottom'].set_linewidth(bwidth) # guides - if "grid_params" in self.plot_params.keys(): - self.ax.grid(axis='x', **self.plot_params["grid_params"]) - else: - self.ax.grid(axis='x', lw=1.2) + if '' not in keys: + if "grid_params" in self.plot_params.keys(): + self.ax.grid(axis='x', **self.plot_params["grid_params"]) + else: + self.ax.grid(axis='x', lw=1.2) if "hline_params" in self.plot_params.keys(): self.ax.axhline(0, **self.plot_params["hline_params"]) else: @@ -147,166 +332,432 @@ def _set_figure(self, index: dict, range: Sequence): self.ax.legend(by_label.values(), by_label.keys(), prop={'size': 15}) - def plot(self, x: Sequence, y: Sequence, index: Sequence, efermi: float = 0, energy_range: Sequence[float] = []): - """Plot band structure - - :params x, y: x-axis and y-axis coordinates - :params index: special k-points label and its index in data file - :params efermi: Fermi level in unit eV - :params energy_range: range of energy to plot, its length equals to two - """ - - if not self._color: - self._color = 'black' - kpoints, energy = x, y - energy = energy_minus_efermi(energy, efermi) +class PBand(Band): + def __init__(self, bandfile: Union[PathLike, Sequence[PathLike]] = None, kptfile: str = '') -> None: + self.bandfile = bandfile + if isinstance(bandfile, list) or isinstance(bandfile, tuple): + self.energy = [] + self.orbitals = [] + for file in self.bandfile: + self.nspin, self.norbitals, self.eunit, self.nbands, self.nkpoints, self.k_index, e, orb = self.read( + file) + self._check_energy(e) + self.energy.append(e) + self.orbitals.append(orb) + else: + self.nspin, self.norbitals, self.eunit, self.nbands, self.nkpoints, self.k_index, self.energy, self.orbitals = self.read( + self.bandfile) + self._check_energy(self.energy) + self.energy = np.asarray(self.energy) + self.kptfile = kptfile + self.kpt = None + if self.kptfile: + self.kpt = read_kpt(kptfile) + self.k_index = list(map(int, self.k_index)) + + def _check_energy(self, energy): + assert energy.shape[0] == self.nkpoints, "The dimension of band structure dismatches with the number of k-points." + assert energy.shape[1] == self.nbands, "The dimension of band structure dismatches with the number of bands." + + def _check_weights(self, weights: np.ndarray, prec=1e-5): + assert weights.shape[0] == self.norbitals, "The dimension of weights dismatches with the number of orbitals." + assert weights.shape[1] == self.nkpoints, "The dimension of weights dismatches with the number of k-points." + assert weights.shape[2] == self.nbands, "The dimension of weights dismatches with the number of bands." + one_mat = np.ones((self.nkpoints, self.nbands)) + assert (np.abs(weights.sum(axis=0)-one_mat) < prec).all( + ), f"np.abs(weights.sum(axis=0)-np.ones(({self.nkpoints}, {self.nbands}))) < {prec}" + + @property + def weights(self): + data = np.empty((self.norbitals, self.nkpoints, self.nbands)) + for i, orb in enumerate(self.orbitals): + data[i] = orb['data'] + self._check_weights(data) + return data - self.ax.plot(kpoints, energy, lw=self._lw, color=self._color, - label=self._label, linestyle=self._linestyle) - self._set_figure(index, energy_range) + @classmethod + def read(cls, filename: PathLike): + """Read projected band data file and return k-points, energy and Mulliken weights - def singleplot(self, datafile: PathLike, kptfile: str = '', efermi: float = 0, energy_range: Sequence[float] = [], shift: bool = False): - """Plot band structure using data file + :params bandfile: string of projected band data file + """ - :params datafile: string of band date file - :params kptfile: k-point file - :params efermi: Fermi level in unit eV - :params energy_range: range of energy to plot, its length equals to two - :params shift: if sets True, it will calculate band gap. This parameter usually is suitable for semiconductor and insulator. Default: False + from lxml import etree + pbanddata = etree.parse(filename) + root = pbanddata.getroot() + nspin = int(root.xpath('//nspin')[0].text.replace(' ', '')) + norbitals = int(root.xpath('//norbitals') + [0].text.replace(' ', '')) + eunit = root.xpath('//band_structure/@units')[0].replace(' ', '') + nbands = int(root.xpath('//band_structure/@nbands') + [0].replace(' ', '')) + nkpoints = int(root.xpath('//band_structure/@nkpoints') + [0].replace(' ', '')) + k_index = np.arange(nkpoints) + energy = root.xpath('//band_structure')[0].text.split('\n') + energy = handle_data(energy) + remove_empty(energy) + energy = np.asarray(energy, dtype=float) + + orbitals = [] + for i in range(norbitals): + orb = OrderedDict() + o_index_str = root.xpath( + '//orbital/@index')[i] + orb['index'] = int(o_index_str.replace(' ', '')) + orb['atom_index'] = int(root.xpath( + '//orbital/@atom_index')[i].replace(' ', '')) + orb['species'] = root.xpath( + '//orbital/@species')[i].replace(' ', '') + orb['l'] = int(root.xpath('//orbital/@l')[i].replace(' ', '')) + orb['m'] = int(root.xpath('//orbital/@m')[i].replace(' ', '')) + orb['z'] = int(root.xpath('//orbital/@z')[i].replace(' ', '')) + data = root.xpath('//data')[i].text.split('\n') + data = handle_data(data) + remove_empty(data) + orb['data'] = np.asarray(data, dtype=float) + orbitals.append(orb) + + return nspin, norbitals, eunit, nbands, nkpoints, k_index, energy, orbitals + + def _write(self, species: Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], keyname='', file_dir:PathLike=''): + """Write parsed projected bands data to files + + Args: + orbital (dict): parsed data + species (Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], optional): list of atomic species(index or atom index) or dict of atomic species(index or atom index) and its angular momentum list. Defaults to []. + keyname (str): the keyword that extracts the PBAND. Allowed values: 'index', 'atom_index', 'species' """ - kpt = read_kpt(kptfile) + band, totnum = parse_projected_data(self.orbitals, species, keyname) + + if isinstance(species, (list, tuple)): + for elem in band.keys(): + header_list = [''] + with open(file_dir/f"{keyname}-{elem}.dat", 'w') as f: + header_list.append( + f"Projected band structure for {keyname}: {elem}") + header_list.append('') + header_list.append( + f'\tNumber of k-points: {self.nkpoints}') + header_list.append(f'\tNumber of bands: {self.nbands}') + header_list.append('') + for orb in self.orbitals: + if orb[keyname] == elem: + header_list.append( + f"\tAdd data for index ={orb['index']:4d}, atom_index ={orb['atom_index']:4d}, element ={orb['species']:4s}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") + header_list.append('') + header_list.append( + f'Data shape: ({self.nkpoints}, {self.nbands})') + header_list.append('') + header = '\n'.join(header_list) + np.savetxt(f, band[elem], header=header) + + elif isinstance(species, dict): + for elem in band.keys(): + elem_file_dir = file_dir/f"{keyname}-{elem}" + elem_file_dir.mkdir(exist_ok=True) + for ang in band[elem].keys(): + l_index = int(ang) + if isinstance(band[elem][ang], dict): + for mag in band[elem][ang].keys(): + header_list = [''] + m_index = int(mag) + with open(elem_file_dir/f"{keyname}-{elem}_{ang}_{mag}.dat", 'w') as f: + header_list.append( + f"Projected band structure for {keyname}: {elem}") + header_list.append('') + header_list.append( + f'\tNumber of k-points: {self.nkpoints}') + header_list.append( + f'\tNumber of bands: {self.nbands}') + header_list.append('') + for orb in self.orbitals: + if orb[keyname] == elem and orb["l"] == l_index and orb["m"] == m_index: + header_list.append( + f"\tAdd data for index ={orb['index']:4d}, atom_index ={orb['atom_index']:4d}, element ={orb['species']:4s}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") + header_list.append('') + header_list.append( + f'Data shape: ({self.nkpoints}, {self.nbands})') + header_list.append('') + header = '\n'.join(header_list) + np.savetxt(f, band[elem][ang] + [mag], header=header) + + else: + header_list = [''] + with open(elem_file_dir/f"{keyname}-{elem}_{ang}.dat", 'w') as f: + header_list.append( + f"Projected band structure for {keyname}: {elem}") + header_list.append('') + header_list.append( + f'\tNumber of k-points: {self.nkpoints}') + header_list.append( + f'\tNumber of bands: {self.nbands}') + header_list.append('') + for orb in self.orbitals: + if orb[keyname] == elem and orb["l"] == l_index: + header_list.append( + f"\tAdd data for index ={orb['index']:4d}, atom_index ={orb['atom_index']:4d}, element ={orb['species']:4s}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") + header_list.append('') + header_list.append( + f'Data shape: ({self.nkpoints}, {self.nbands})') + header_list.append('') + header = '\n'.join(header_list) + np.savetxt(f, band[elem][ang], header=header) + + def write(self, + index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + atom_index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + species: Union[Sequence[str], Dict[str, List[int]], + Dict[str, Dict[int, List[int]]]] = [], + outdir: PathLike = './' + ): + """Write parsed partial dos data to files + + Args: + index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PDOS of each atom. Defaults to []. + atom_index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PDOS of each atom with same atom_index. Defaults to []. + species (Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[int, List[int]]]], optional): extract PDOS of each atom with same species. Defaults to []. + outdir (PathLike, optional): directory of parsed PDOS files. Defaults to './'. + """ - if not self._color: - self._color = 'black' + if isinstance(self.bandfile, list): + for i in range(len(self.bandfile)): + file_dir = Path(f"{outdir}", f"PBAND{i}_FILE") + file_dir.mkdir(exist_ok=True) + if index: + self._write(index, keyname='index', + file_dir=file_dir) + if atom_index: + self._write(atom_index, keyname='atom_index', + file_dir=file_dir) + if species: + self._write(species, keyname='species', + file_dir=file_dir) - kpoints, energy = self.read(datafile) - energy = energy_minus_efermi(energy, efermi) - if shift: - vb, cb = self.set_vcband(energy) - self.ax.plot(kpoints, np.vstack((vb.band, cb.band)).T, - lw=self._lw, color=self._color, label=self._label, linestyle=self._linestyle) - self.info(kpt.full_kpath, vb, cb) else: - self.ax.plot(kpoints, energy, - lw=self._lw, color=self._color, label=self._label, linestyle=self._linestyle) - index = kpt.label_special_k - self._set_figure(index, energy_range) - - def multiplot(self, datafile: Sequence[PathLike], kptfile: str = '', efermi: Sequence[float] = [], energy_range: Sequence[float] = [], shift: bool = True): - """Plot more than two band structures using data file - - :params datafile: list of path of band date file - :params kptfile: k-point file - :params efermi: list of Fermi levels in unit eV, its length equals to `filename` - :params energy_range: range of energy to plot, its length equals to two - :params shift: if sets True, it will calculate band gap. This parameter usually is suitable for semiconductor and insulator. Default: False + file_dir = Path(f"{outdir}", f"PBAND{1}_FILE") + file_dir.mkdir(exist_ok=True) + if index: + self._write(index, keyname='index', + file_dir=file_dir) + if atom_index: + self._write(atom_index, keyname='atom_index', + file_dir=file_dir) + if species: + self._write(species, keyname='species', + file_dir=file_dir) + + def _plot(self, + fig: Figure, + ax: axes.Axes, + energy: np.ndarray, + species: Union[Sequence[Any], Dict[Any, List[int]], + Dict[Any, Dict[int, List[int]]]] = [], + efermi: float = 0, + energy_range: Sequence[float] = [], + shift: bool = False, + keyname: str = '', + outdir: PathLike = './', + out_index: int = 1, + cmap='jet', + **kwargs): + """Plot parsed projected bands data + + Args: + fig (Figure): object of matplotlib.figure.Figure + ax (Union[axes.Axes, Sequence[axes.Axes]]): object of matplotlib.axes.Axes or a list of this objects + species (Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], optional): list of atomic species(index or atom index) or dict of atomic species(index or atom index) and its angular momentum list. Defaults to []. + efermi (float, optional): fermi level in unit eV. Defaults to 0. + energy_range (Sequence[float], optional): energy range in unit eV for plotting. Defaults to []. + shift (bool, optional): if shift energy by fermi level and set the VBM to zero, or not. Defaults to False. + keyname (str, optional): the keyword that extracts the PBANDS. Defaults to ''. + + Returns: + BandPlot object: for manually plotting picture with bandplot.ax """ - kpt = read_kpt(kptfile) - - if not efermi: - efermi = [0.0 for i in range(len(datafile))] - if not self._label: - self._label = ['' for i in range(len(datafile))] - if not self._color: - self._color = ['black' for i in range(len(datafile))] - if not self._linestyle: - self._linestyle = ['solid' for i in range(len(datafile))] - - emin = -np.inf - emax = np.inf - for i, file in enumerate(datafile): - kpoints, energy = self.read(file) - if shift: - vb, cb = self.set_vcband( - energy_minus_efermi(energy, efermi[i])) - energy_min = np.min(vb.band) - energy_max = np.max(cb.band) - if energy_min > emin: - emin = energy_min - if energy_max < emax: - emax = energy_max - - self.ax.plot(kpoints, np.vstack((vb.band, cb.band)).T, - lw=self._lw, color=self._color[i], label=self._label[i], linestyle=self._linestyle[i]) - self.info(kpt.full_kpath, vb, cb) - else: - self.ax.plot(kpoints, energy_minus_efermi(energy, efermi[i]), - lw=self._lw, color=self._color[i], label=self._label[i], linestyle=self._linestyle[i]) - - index = kpt.label_special_k - self._set_figure(index, energy_range) - - @classmethod - def bandgap(cls, vb: namedtuple, cb: namedtuple): - """Calculate band gap""" - - gap = cb.value-vb.value + def _seg_plot(bandplot, lc, index, file_dir, name): + cbar = bandplot.fig.colorbar(lc, ax=bandplot.ax) + bandplot._set_figure(index, energy_range) + bandplot.fig.savefig(file_dir/f'{keyname}-{bandplot._label}.pdf', dpi=400) + cbar.remove() + plt.cla() - return gap + return bandplot - @classmethod - def info(cls, kpath: Sequence, vb: namedtuple, cb: namedtuple): - """Output the information of band structure + wei, totnum = parse_projected_data(self.orbitals, species, keyname) + energy = self._shift_energy(energy, efermi, shift) + file_dir = Path(f"{outdir}", f"PBAND{out_index}_FIG") + file_dir.mkdir(exist_ok=True) - :params kpath: k-points path - :params energy: band energy after subtracting the Fermi level + if self.kpt: + index = self.kpt.label_special_k + else: + index = self.k_index + + if not species: + bandplot = BandPlot(fig, ax, **kwargs) + bandplot = super().plot(fig, ax, efermi, energy_range, shift, **kwargs) + bandplot._set_figure(index, energy_range) + + return bandplot + + if isinstance(species, (list, tuple)): + bandplots = [] + for i, elem in enumerate(wei.keys()): + bandplot = BandPlot(fig, ax, **kwargs) + bandplot._label = elem + for ib in range(self.nbands): + points = np.array((self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + norm = Normalize(vmin=wei[elem][0:, ib].min(), vmax=wei[elem][0:, ib].max()) + lc = LineCollection(segments, cmap=plt.get_cmap(cmap), norm=norm) + lc.set_array(wei[elem][0:, ib]) + lc.set_label(bandplot._label) + bandplot.ax.add_collection(lc) + + _seg_plot(bandplot, lc, index, file_dir, name=f'{elem}') + bandplots.append(bandplot) + return bandplots + + elif isinstance(species, dict): + bandplots = [] + for i, elem in enumerate(wei.keys()): + elem_file_dir = file_dir/f"{keyname}-{elem}" + elem_file_dir.mkdir(exist_ok=True) + for ang in wei[elem].keys(): + l_index = int(ang) + if isinstance(wei[elem][ang], dict): + for mag in wei[elem][ang].keys(): + bandplot = BandPlot(fig, ax, **kwargs) + m_index = int(mag) + bandplot._label = f"{elem}-{get_angular_momentum_name(l_index, m_index)}" + for ib in range(self.nbands): + points = np.array((self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + norm = Normalize(vmin=wei[elem][ang][mag][0:, ib].min(), vmax=wei[elem][ang][mag][0:, ib].max()) + lc = LineCollection(segments, cmap=plt.get_cmap(cmap), norm=norm) + lc.set_array(wei[elem][ang][mag][0:, ib]) + lc.set_label(bandplot._label) + bandplot.ax.add_collection(lc) + + _seg_plot(bandplot, lc, index, elem_file_dir, name=f'{elem}_{ang}_{mag}') + bandplots.append(bandplot) + + else: + bandplot = BandPlot(fig, ax, **kwargs) + bandplot._label = f"{elem}-{get_angular_momentum_label(l_index)}" + for ib in range(self.nbands): + points = np.array((self.k_index, energy[0:, ib])).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + norm = Normalize(vmin=wei[elem][ang][0:, ib].min(), vmax=wei[elem][ang][0:, ib].max()) + lc = LineCollection(segments, cmap=plt.get_cmap(cmap), norm=norm) + lc.set_array(wei[elem][ang][0:, ib]) + lc.set_label(bandplot._label) + bandplot.ax.add_collection(lc) + + _seg_plot(bandplot, lc, index, elem_file_dir, name=f'{elem}_{ang}') + bandplots.append(bandplot) + + return bandplots + + plt.clf() + + def plot(self, + fig: Figure, + ax: Union[axes.Axes, Sequence[axes.Axes]], + index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + atom_index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + species: Union[Sequence[str], Dict[str, List[int]], + Dict[str, Dict[int, List[int]]]] = [], + efermi: Union[float, Sequence[float]] = [], + energy_range: Sequence[float] = [], + shift: bool = False, + outdir: PathLike = './', + cmapname='jet', + **kwargs): + """Plot parsed projected band data + + Args: + fig (Figure): object of matplotlib.figure.Figure + ax (Union[axes.Axes, Sequence[axes.Axes]]): object of matplotlib.axes.Axes or a list of this objects + index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PBAND of each atom. Defaults to []. + atom_index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PBAND of each atom with same atom_index. Defaults to []. + species (Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[int, List[int]]]], optional): extract PBAND of each atom with same species. Defaults to []. + efermi (float, optional): fermi level in unit eV. Defaults to 0. + energy_range (Sequence[float], optional): energy range in unit eV for plotting. Defaults to []. + shift (bool, optional): if shift energy by fermi level and set the VBM to zero, or not. Defaults to False. + outdir (PathLike): Default: './' + cmapname (str): Default: 'jet' + + Returns: + BandPlot object: for manually plotting picture with bandplot.ax """ + nums = len(self.bandfile) + + if isinstance(self.bandfile, list): + if not efermi: + efermi = [0.0 for i in range(nums)] + _linestyle = kwargs.pop( + 'linestyle', ['solid' for i in range(nums)]) + + for i, band in enumerate(self.energy): + if not index and not atom_index and not species: + bandplot = self._plot(fig=fig, ax=ax, energy=band, species=[ + ], efermi=efermi[i], energy_range=energy_range, shift=shift, keyname='', linestyle=_linestyle[i], outdir=outdir, out_index=i, cmapname=cmapname, **kwargs) + if index: + bandplot = self._plot(fig=fig, ax=ax, energy=band, species=index, efermi=efermi[i], + energy_range=energy_range, shift=shift, keyname='index', linestyle=_linestyle[i], outdir=outdir, out_index=i, cmapname=cmapname, **kwargs) + if atom_index: + bandplot = self._plot(fig=fig, ax=ax, energy=band, species=atom_index, efermi=efermi[i], + energy_range=energy_range, shift=shift, keyname='atom_index', linestyle=_linestyle[i], outdir=outdir, out_index=i, cmapname=cmapname, **kwargs) + if species: + bandplot = self._plot(fig=fig, ax=ax, energy=band, species=species, efermi=efermi[i], + energy_range=energy_range, shift=shift, keyname='species', linestyle=_linestyle[i], outdir=outdir, out_index=i, cmapname=cmapname, **kwargs) - def band_type(vbm_x, cbm_x): - longone, shortone = (vbm_x, cbm_x) if len( - vbm_x) >= len(cbm_x) else (cbm_x, vbm_x) - for i in shortone: - if i in longone: - btype = "Direct" - else: - btype = "Indirect" - return btype - - gap = cls.bandgap(vb, cb) - print( - "--------------------------Band Structure--------------------------", flush=True) - print( - f"{'Band character:'.ljust(30)}{band_type(vb.k_index, cb.k_index)}", flush=True) - print(f"{'Band gap(eV):'.ljust(30)}{gap: .4f}", flush=True) - print(f"{'Band index:'.ljust(30)}{'HOMO'.ljust(10)}{'LUMO'}", flush=True) - print( - f"{''.ljust(30)}{str(vb.band_index[-1]).ljust(10)}{str(cb.band_index[0])}", flush=True) - print(f"{'Eigenvalue of VBM(eV):'.ljust(30)}{vb.value: .4f}", flush=True) - print(f"{'Eigenvalue of CBM(eV):'.ljust(30)}{cb.value: .4f}", flush=True) - vbm_k = np.unique(kpath[vb.k_index], axis=0) - cbm_k = np.unique(kpath[cb.k_index], axis=0) - print( - f"{'Location of VBM'.ljust(30)}{' '.join(list_elem2str(vbm_k[0]))}", flush=True) - for i, j in enumerate(vbm_k): - if i != 0: - print(f"{''.ljust(30)}{' '.join(list_elem2str(j))}", flush=True) - print( - f"{'Location of CBM'.ljust(30)}{' '.join(list_elem2str(cbm_k[0]))}", flush=True) - for i, j in enumerate(cbm_k): - if i != 0: - print(f"{''.ljust(30)}{' '.join(list_elem2str(j))}", flush=True) + else: + if not index and not atom_index and not species: + bandplot = self._plot(fig=fig, ax=ax, energy=self.energy, species=[ + ], efermi=efermi, energy_range=energy_range, shift=shift, keyname='', outdir=outdir, out_index=1, **kwargs) + if index: + bandplot = self._plot(fig=fig, ax=ax, energy=self.energy, species=index, efermi=efermi, + energy_range=energy_range, shift=shift, keyname='index', outdir=outdir, out_index=1, **kwargs) + if atom_index: + bandplot = self._plot(fig=fig, ax=ax, energy=self.energy, species=atom_index, efermi=efermi, + energy_range=energy_range, shift=shift, keyname='atom_index', outdir=outdir, out_index=1, **kwargs) + if species: + bandplot = self._plot(fig=fig, ax=ax, energy=self.energy, species=species, efermi=efermi, + energy_range=energy_range, shift=shift, keyname='species', outdir=outdir, out_index=1, **kwargs) + + return bandplot if __name__ == "__main__": import matplotlib.pyplot as plt from pathlib import Path - parent = Path(r"D:\ustc\TEST\HOIP\double HOIP\result\bond") - name = "CsAgBiBr" + parent = Path(r"C:\Users\YY.Ji\Desktop") + name = "PBANDS_1" path = parent/name - notes = {'s': '(b)'} - datafile = [path/"soc.dat", path/"non-soc.dat"] - kptfile = path/"KPT" - fig, ax = plt.subplots(figsize=(12, 12)) - label = ["with SOC", "without SOC"] - color = ["r", "g"] - linestyle = ["solid", "dashed"] - band = BandPlot(fig, ax, notes=notes, label=label, - color=color, linestyle=linestyle) + # notes = {'s': '(b)'} + # datafile = [path/"soc.dat", path/"non-soc.dat"] + # kptfile = path/"KPT" + fig, ax = plt.subplots(figsize=(12, 6)) + # label = ["with SOC", "without SOC"] + # color = ["r", "g"] + # linestyle = ["solid", "dashed"] energy_range = [-5, 6] - efermi = [4.417301755850272, 4.920435541999894] - shift = True - band.multiplot(datafile, kptfile, efermi, energy_range, shift) - fig.savefig("band.png") + efermi = 4.417301755850272 + shift = False + #species = {"Ag": [2], "Cl": [1], "In": [0]} + atom_index = {8: {2: [1, 2]}, 4: {2: [1, 2]}, 10: [1, 2]} + pband = PBand(str(path)) + pband.plot(fig, ax, atom_index=atom_index, efermi=efermi, + energy_range=energy_range, shift=shift) + pband.write(atom_index=atom_index) diff --git a/tools/plot-tools/abacus_plot/dos.py b/tools/plot-tools/abacus_plot/dos.py index a289f78641..da1296008f 100644 --- a/tools/plot-tools/abacus_plot/dos.py +++ b/tools/plot-tools/abacus_plot/dos.py @@ -5,16 +5,16 @@ Mail: jiyuyang@mail.ustc.edu.cn, 1041176461@qq.com ''' -from collections import OrderedDict, defaultdict, namedtuple +from collections import OrderedDict, namedtuple import numpy as np from os import PathLike from pathlib import Path -from typing import Dict, List, Sequence, Tuple, Union +from typing import Dict, List, Sequence, Tuple, Union, Any from matplotlib.figure import Figure from matplotlib import axes from abacus_plot.utils import (energy_minus_efermi, get_angular_momentum_label, - get_angular_momentum_name, remove_empty) + get_angular_momentum_name, remove_empty, handle_data, parse_projected_data) class DOS: @@ -92,14 +92,14 @@ def bandgap(cls, vb: namedtuple, cb: namedtuple): class DOSPlot: """Plot density of state(DOS)""" - def __init__(self, fig: Figure, ax: axes.Axes, **kwargs) -> None: + def __init__(self, fig: Figure = None, ax: axes.Axes = None, **kwargs) -> None: self.fig = fig self.ax = ax self._lw = kwargs.pop('lw', 2) self._bwidth = kwargs.pop('bwdith', 3) self.plot_params = kwargs - def _set_figure(self, energy_range: Sequence, dos_range: Sequence, notes: Dict = {}): + def _set_figure(self, energy_range: Sequence = [], dos_range: Sequence = [], notes: Dict = {}): """set figure and axes for plotting :params energy_range: range of energy @@ -154,7 +154,7 @@ def _set_figure(self, energy_range: Sequence, dos_range: Sequence, notes: Dict = class TDOS(DOS): """Parse total DOS data""" - def __init__(self, tdosfile: PathLike) -> None: + def __init__(self, tdosfile: PathLike=None) -> None: super().__init__() self.tdosfile = tdosfile self._read() @@ -202,7 +202,7 @@ def plot(self, fig: Figure, ax: Union[axes.Axes, Sequence[axes.Axes]], efermi: f class PDOS(DOS): """Parse partial DOS data""" - def __init__(self, pdosfile: PathLike) -> None: + def __init__(self, pdosfile: PathLike=None) -> None: super().__init__() self.pdosfile = pdosfile self._read() @@ -213,15 +213,6 @@ def _read(self): :params pdosfile: string of PDOS data file """ - def handle_data(data): - data.remove('') - - def handle_elem(elem): - elist = elem.split(' ') - remove_empty(elist) # `list` will be modified in function - return elist - return list(map(handle_elem, data)) - from lxml import etree pdosdata = etree.parse(self.pdosfile) root = pdosdata.getroot() @@ -257,74 +248,15 @@ def _all_sum(self) -> Tuple[np.ndarray, int]: res = res + orb['data'] return res - def parse(self, species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[str, List[int]]]]): - """Extract partial dos from file - - Args: - species (Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[str, List[int]]]], optional): list of atomic species or dict of atomic species and its angular momentum list. Defaults to []. - """ - - if isinstance(species, (list, tuple)): - dos = {} - elements = species - for elem in elements: - count = 0 - dos_temp = np.zeros_like(self.orbitals[0]["data"], dtype=float) - for orb in self.orbitals: - if orb["species"] == elem: - dos_temp += orb["data"] - count += 1 - if count: - dos[elem] = dos_temp - - return dos - - elif isinstance(species, dict): - dos = defaultdict(dict) - elements = list(species.keys()) - l = list(species.values()) - for i, elem in enumerate(elements): - if isinstance(l[i], dict): - for ang, mag in l[i].items(): - l_count = 0 - l_index = int(ang) - l_dos = {} - for m_index in mag: - m_count = 0 - dos_temp = np.zeros_like( - self.orbitals[0]["data"], dtype=float) - for orb in self.orbitals: - if orb["species"] == elem and orb["l"] == l_index and orb["m"] == m_index: - dos_temp += orb["data"] - m_count += 1 - l_count += 1 - if m_count: - l_dos[m_index] = dos_temp - if l_count: - dos[elem][l_index] = l_dos - - elif isinstance(l[i], list): - for l_index in l[i]: - count = 0 - dos_temp = np.zeros_like( - self.orbitals[0]["data"], dtype=float) - for orb in self.orbitals: - if orb["species"] == elem and orb["l"] == l_index: - dos_temp += orb["data"] - count += 1 - if count: - dos[elem][l_index] = dos_temp - - return dos - - def write(self, species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[str, List[int]]]], outdir: PathLike = './'): + def _write(self, species: Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], keyname='', outdir: PathLike = './'): """Write parsed partial dos data to files Args: - species (Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[str, List[int]]]], optional): list of atomic species or dict of atomic species and its angular momentum list. Defaults to []. + species (Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], optional): list of atomic species(index or atom index) or dict of atomic species(index or atom index) and its angular momentum list. Defaults to []. + keyname (str): the keyword that extracts the PDOS. Allowed values: 'index', 'atom_index', 'species' """ - dos = self.parse(species) + dos, totnum = parse_projected_data(self.orbitals, species, keyname) fmt = ['%13.7f', '%15.8f'] if self._nsplit == 1 else [ '%13.7f', '%15.8f', '%15.8f'] file_dir = Path(f"{outdir}", "PDOS_FILE") @@ -334,14 +266,14 @@ def write(self, species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Di for elem in dos.keys(): header_list = [''] data = np.hstack((self.energy.reshape(-1, 1), dos[elem])) - with open(file_dir/f"{elem}.dat", 'w') as f: + with open(file_dir/f"{keyname}-{elem}.dat", 'w') as f: header_list.append( - f"\tpartial DOS for atom species: {elem}") + f"Partial DOS for {keyname}: {elem}") header_list.append('') for orb in self.orbitals: - if orb["species"] == elem: + if orb[keyname] == elem: header_list.append( - f"\tAdd data for atom_index ={orb['atom_index']:4d}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") + f"\tAdd data for index ={orb['index']:4d}, atom_index ={orb['atom_index']:4d}, element ={orb['species']:4s}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") header_list.append('') header_list.append('\tEnergy'+10*' ' + 'spin 1'+8*' '+'spin 2') @@ -351,7 +283,7 @@ def write(self, species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Di elif isinstance(species, dict): for elem in dos.keys(): - elem_file_dir = file_dir/f"{elem}" + elem_file_dir = file_dir/f"{keyname}-{elem}" elem_file_dir.mkdir(exist_ok=True) for ang in dos[elem].keys(): l_index = int(ang) @@ -361,14 +293,14 @@ def write(self, species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Di data = np.hstack( (self.energy.reshape(-1, 1), dos[elem][ang][mag])) m_index = int(mag) - with open(elem_file_dir/f"{elem}_{ang}_{mag}.dat", 'w') as f: + with open(elem_file_dir/f"{keyname}-{elem}_{ang}_{mag}.dat", 'w') as f: header_list.append( - f"\tpartial DOS for atom species: {elem}") + f"Partial DOS for {keyname}: {elem}") header_list.append('') for orb in self.orbitals: - if orb["species"] == elem and orb["l"] == l_index and orb["m"] == m_index: + if orb[keyname] == elem and orb["l"] == l_index and orb["m"] == m_index: header_list.append( - f"\tAdd data for atom_index ={orb['atom_index']:4d}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") + f"\tAdd data for index ={orb['index']:4d}, atom_index ={orb['atom_index']:4d}, element ={orb['species']:4s}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") header_list.append('') header_list.append( '\tEnergy'+10*' '+'spin 1'+8*' '+'spin 2') @@ -380,14 +312,14 @@ def write(self, species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Di header_list = [''] data = np.hstack( (self.energy.reshape(-1, 1), dos[elem][ang])) - with open(elem_file_dir/f"{elem}_{ang}.dat", 'w') as f: + with open(elem_file_dir/f"{keyname}-{elem}_{ang}.dat", 'w') as f: header_list.append( - f"\tpartial DOS for atom species: {elem}") + f"Partial DOS for {keyname}: {elem}") header_list.append('') for orb in self.orbitals: - if orb["species"] == elem and orb["l"] == l_index: + if orb[keyname] == elem and orb["l"] == l_index: header_list.append( - f"\tAdd data for atom_index ={orb['atom_index']:4d}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") + f"\tAdd data for index ={orb['index']:4d}, atom_index ={orb['atom_index']:4d}, element ={orb['species']:4s}, l,m,z={orb['l']:3d}, {orb['m']:3d}, {orb['z']:3d}") header_list.append('') header_list.append( '\tEnergy'+10*' '+'spin 1'+8*' '+'spin 2') @@ -395,6 +327,30 @@ def write(self, species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Di header = '\n'.join(header_list) np.savetxt(f, data, fmt=fmt, header=header) + def write(self, + index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + atom_index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + species: Union[Sequence[str], Dict[str, List[int]], + Dict[str, Dict[int, List[int]]]] = [], + outdir: PathLike = './' + ): + """Write parsed partial dos data to files + + Args: + index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PDOS of each atom. Defaults to []. + atom_index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PDOS of each atom with same atom_index. Defaults to []. + species (Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[int, List[int]]]], optional): extract PDOS of each atom with same species. Defaults to []. + outdir (PathLike, optional): directory of parsed PDOS files. Defaults to './'. + """ + if index: + self._write(index, keyname='index', outdir=outdir) + if atom_index: + self._write(atom_index, keyname='atom_index', outdir=outdir) + if species: + self._write(species, keyname='species', outdir=outdir) + def _shift_energy(self, efermi: float = 0, shift: bool = False, prec: float = 0.01): tdos = self._all_sum() if shift: @@ -407,10 +363,36 @@ def _shift_energy(self, efermi: float = 0, shift: bool = False, prec: float = 0. return energy_f, tdos - def plot(self, fig: Figure, ax: Union[axes.Axes, Sequence[axes.Axes]], species: Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[str, List[int]]]] = [], efermi: float = 0, energy_range: Sequence[float] = [], dos_range: Sequence[float] = [], shift: bool = False, prec: float = 0.01, **kwargs): - """Plot partial DOS""" + def _parial_plot(self, + fig: Figure, + ax: Union[axes.Axes, Sequence[axes.Axes]], + species: Union[Sequence[Any], Dict[Any, List[int]], + Dict[Any, Dict[int, List[int]]]] = [], + efermi: float = 0, + energy_range: Sequence[float] = [], + dos_range: Sequence[float] = [], + shift: bool = False, + prec: float = 0.01, + keyname: str = '', + **kwargs): + """Plot parsed partial dos data - dos = self.parse(species) + Args: + fig (Figure): object of matplotlib.figure.Figure + ax (Union[axes.Axes, Sequence[axes.Axes]]): object of matplotlib.axes.Axes or a list of this objects + species (Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[int, List[int]]]], optional): list of atomic species(index or atom index) or dict of atomic species(index or atom index) and its angular momentum list. Defaults to []. + efermi (float, optional): fermi level in unit eV. Defaults to 0. + energy_range (Sequence[float], optional): energy range in unit eV for plotting. Defaults to []. + dos_range (Sequence[float], optional): dos range for plotting. Defaults to []. + shift (bool, optional): if shift energy by fermi level and set the VBM to zero, or not. Defaults to False. + prec (float, optional): precision for treating dos as zero. Defaults to 0.01. + keyname (str, optional): the keyword that extracts the PDOS. Defaults to ''. + + Returns: + DOSPlot object: for manually plotting picture with dosplot.ax + """ + + dos, totnum = parse_projected_data(self.orbitals, species, keyname) energy_f, tdos = self._shift_energy(efermi, shift, prec) if not species: @@ -421,10 +403,10 @@ def plot(self, fig: Figure, ax: Union[axes.Axes, Sequence[axes.Axes]], species: notes=dosplot.plot_params["notes"]) else: dosplot._set_figure(energy_range, dos_range) - + return dosplot - elif isinstance(species, (list, tuple)): + if isinstance(species, (list, tuple)): dosplot = DOSPlot(fig, ax, **kwargs) if "xlabel_params" in dosplot.plot_params.keys(): dosplot.ax.set_xlabel("Energy(eV)", ** @@ -473,24 +455,74 @@ def plot(self, fig: Figure, ax: Union[axes.Axes, Sequence[axes.Axes]], species: return dosplots + def plot(self, + fig: Figure, + ax: Union[axes.Axes, Sequence[axes.Axes]], + index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + atom_index: Union[Sequence[int], Dict[int, List[int]], + Dict[int, Dict[int, List[int]]]] = [], + species: Union[Sequence[str], Dict[str, List[int]], + Dict[str, Dict[int, List[int]]]] = [], + efermi: float = 0, + energy_range: Sequence[float] = [], + dos_range: Sequence[float] = [], + shift: bool = False, + prec: float = 0.01, + **kwargs): + """Plot parsed partial dos data + + Args: + fig (Figure): object of matplotlib.figure.Figure + ax (Union[axes.Axes, Sequence[axes.Axes]]): object of matplotlib.axes.Axes or a list of this objects + index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PDOS of each atom. Defaults to []. + atom_index (Union[Sequence[int], Dict[int, List[int]], Dict[int, Dict[int, List[int]]]], optional): extract PDOS of each atom with same atom_index. Defaults to []. + species (Union[Sequence[str], Dict[str, List[int]], Dict[str, Dict[int, List[int]]]], optional): extract PDOS of each atom with same species. Defaults to []. + efermi (float, optional): fermi level in unit eV. Defaults to 0. + energy_range (Sequence[float], optional): energy range in unit eV for plotting. Defaults to []. + dos_range (Sequence[float], optional): dos range for plotting. Defaults to []. + shift (bool, optional): if shift energy by fermi level and set the VBM to zero, or not. Defaults to False. + prec (float, optional): precision for treating dos as zero. Defaults to 0.01. + + Returns: + DOSPlot object: for manually plotting picture with dosplot.ax + """ + + if not index and not atom_index and not species: + dosplot = self._parial_plot(fig=fig, ax=ax, species=[ + ], efermi=efermi, energy_range=energy_range, dos_range=dos_range, shift=shift, prec=prec, keyname='', **kwargs) + if index: + dosplot = self._parial_plot(fig=fig, ax=ax, species=index, efermi=efermi, energy_range=energy_range, + dos_range=dos_range, shift=shift, prec=prec, keyname='index', **kwargs) + if atom_index: + dosplot = self._parial_plot(fig=fig, ax=ax, species=atom_index, efermi=efermi, energy_range=energy_range, + dos_range=dos_range, shift=shift, prec=prec, keyname='atom_index', **kwargs) + if species: + dosplot = self._parial_plot(fig=fig, ax=ax, species=species, efermi=efermi, energy_range=energy_range, + dos_range=dos_range, shift=shift, prec=prec, keyname='species', **kwargs) + + return dosplot + if __name__ == "__main__": import matplotlib.pyplot as plt pdosfile = r"C:\Users\YY.Ji\Desktop\PDOS" pdos = PDOS(pdosfile) - species = {"Cs": [0, 1], "Na": [0, 1]} - fig, ax = plt.subplots(2, 1, sharex=True) - energy_range = [-6, 10] + #species = {"Ag": [2], "Cl": [1], "In": [0]} + atom_index = {8: {2: [1, 2]}, 4: {2: [1, 2]}, 10: [1, 2]} + fig, ax = plt.subplots(3, 1, sharex=True) + energy_range = [-1.5, 6] dos_range = [0, 5] - dosplots = pdos.plot(fig, ax, species, efermi=5, shift=True, - energy_range=energy_range, dos_range=dos_range, notes=[{'s': '(a)'}, {'s': '(b)'}]) + dosplots = pdos.plot(fig, ax, atom_index=atom_index, efermi=5, shift=True, + energy_range=energy_range, dos_range=dos_range, notes=[{'s': '(a)'}, {'s': '(b)'}, {'s': '(c)'}]) fig.savefig("pdos.png") - - tdosfile = r"C:\Users\YY.Ji\Desktop\TDOS" - tdos = TDOS(tdosfile) - fig, ax = plt.subplots() - energy_range = [-6, 10] - dos_range = [0, 5] - dosplots = tdos.plot(fig, ax, efermi=5, shift=True, - energy_range=energy_range, dos_range=dos_range, notes={'s': '(a)'}) - fig.savefig("tdos.png") + pdos.write(atom_index=atom_index) + + #tdosfile = r"C:\Users\YY.Ji\Desktop\TDOS" + #tdos = TDOS(tdosfile) + #fig, ax = plt.subplots() + #energy_range = [-6, 10] + #dos_range = [0, 5] + # dosplots = tdos.plot(fig, ax, efermi=5, shift=True, + # energy_range=energy_range, dos_range=dos_range, notes={'s': '(a)'}) + # fig.savefig("tdos.png") diff --git a/tools/plot-tools/abacus_plot/main.py b/tools/plot-tools/abacus_plot/main.py index c90293755b..aea7bef259 100644 --- a/tools/plot-tools/abacus_plot/main.py +++ b/tools/plot-tools/abacus_plot/main.py @@ -6,10 +6,9 @@ ''' import argparse -from os import PathLike import matplotlib.pyplot as plt -from abacus_plot.band import BandPlot +from abacus_plot.band import Band, PBand from abacus_plot.dos import TDOS, PDOS from abacus_plot.utils import read_json @@ -19,25 +18,50 @@ class Show: @classmethod def show_cmdline(cls, args): - if args.band: - text = read_json(args.band) - datafile = text["bandfile"] + if args.band and not args.projected and not args.out: + text = read_json(args.file) + bandfile = text["bandfile"] kptfile = text["kptfile"] efermi = text.pop("efermi", 0.0) energy_range = text.pop("energy_range", []) shift = text.pop("shift", False) figsize = text.pop("figsize", (12, 10)) fig, ax = plt.subplots(figsize=figsize) - band = BandPlot(fig, ax) - if isinstance(datafile, (str, PathLike)): - band.singleplot(datafile, kptfile, efermi, energy_range, shift) - elif isinstance(datafile, (list, tuple)): - band.multiplot(datafile, kptfile, efermi, energy_range, shift) + band = Band(bandfile, kptfile) + band.plot(fig=fig, ax=ax, efermi=efermi, energy_range=energy_range, shift=shift, **text) dpi = text.pop("dpi", 400) fig.savefig(text.pop("outfile", "band.png"), dpi=dpi) - if args.tdos: - text = read_json(args.tdos) + if args.band and args.projected: + text = read_json(args.file) + bandfile = text["bandfile"] + kptfile = text["kptfile"] + efermi = text.pop("efermi", 0.0) + energy_range = text.pop("energy_range", []) + shift = text.pop("shift", False) + figsize = text.pop("figsize", (12, 10)) + fig, ax = plt.subplots(figsize=figsize) + index = text.pop("index", []) + atom_index = text.pop("atom_index", []) + species = text.pop("species", []) + outdir = text.pop("outdir", './') + cmapname = text.pop("cmapname", 'jet') + pband = PBand(bandfile, kptfile) + pband.plot(fig=fig, ax=ax, index=index, atom_index=atom_index, species=species, efermi=efermi, energy_range=energy_range, shift=shift, outdir=outdir, cmapname=cmapname, **text) + + if args.band and args.out: + text = read_json(args.file) + bandfile = text["bandfile"] + kptfile = text["kptfile"] + index = text.pop("index", []) + atom_index = text.pop("atom_index", []) + species = text.pop("species", []) + outdir = text.pop("outdir", './') + pdos = PBand(bandfile, kptfile) + pdos.write(index=index, atom_index=atom_index, species=species, outdir=outdir) + + if args.dos and not args.projected and not args.out: + text = read_json(args.file) tdosfile = text.pop("tdosfile", '') efermi = text.pop("efermi", 0.0) energy_range = text.pop("energy_range", []) @@ -52,45 +76,51 @@ def show_cmdline(cls, args): dpi = text.pop("dpi", 400) fig.savefig(text.pop("tdosfig", "tdos.png"), dpi=dpi) - if args.pdos: - text = read_json(args.pdos) + if args.dos and args.projected: + text = read_json(args.file) pdosfile = text.pop("pdosfile", '') efermi = text.pop("efermi", 0.0) energy_range = text.pop("energy_range", []) dos_range = text.pop("dos_range", []) shift = text.pop("shift", False) + index = text.pop("index", []) + atom_index = text.pop("atom_index", []) species = text.pop("species", []) figsize = text.pop("figsize", (12, 10)) fig, ax = plt.subplots( len(species), 1, sharex=True, figsize=figsize) prec = text.pop("prec", 0.01) pdos = PDOS(pdosfile) - pdos.plot(fig, ax, species, efermi=efermi, shift=shift, + pdos.plot(fig=fig, ax=ax, index=index, atom_index=atom_index, species=species, efermi=efermi, shift=shift, energy_range=energy_range, dos_range=dos_range, prec=prec, **text) dpi = text.pop("dpi", 400) fig.savefig(text.pop("pdosfig", "pdos.png"), dpi=dpi) - if args.out_pdos: - text = read_json(args.out_pdos) + if args.dos and args.out: + text = read_json(args.file) pdosfile = text.pop("pdosfile", '') + index = text.pop("index", []) + atom_index = text.pop("atom_index", []) species = text.pop("species", []) outdir = text.pop("outdir", './') pdos = PDOS(pdosfile) - pdos.write(species, outdir) + pdos.write(index=index, atom_index=atom_index, species=species, outdir=outdir) def main(): parser = argparse.ArgumentParser( prog='abacus-plot', description='Plotting tools for ABACUS') # Show - parser.add_argument('-b', '--band', dest='band', type=str, + parser.add_argument('-f', '--file', dest='file', type=str, nargs='?', const='config.json', + default='config.json', help='profile with format json') + parser.add_argument('-b', '--band', dest='band', nargs='?', const=1, type=int, default=None, help='plot band structure and show band information.') - parser.add_argument('-t', '--tdos', dest='tdos', type=str, - default=None, help='plot total density of state(TDOS).') - parser.add_argument('-p', '--pdos', dest='pdos', type=str, - default=None, help='plot partial density of state(PDOS).') - parser.add_argument('-o', '--out_pdos', dest='out_pdos', type=str, - default=None, help='output partial density of state(PDOS).') + parser.add_argument('-d', '--dos', dest='dos', nargs='?', const=1, type=int, + default=None, help='plot density of state(DOS).') + parser.add_argument('-p', '--projected', dest='projected', nargs='?', const=1, type=int, + default=None, help='plot projected band structure or partial density of state(PDOS), should be used with `-b` or `-d`.') + parser.add_argument('-o', '--out_parsed_data', dest='out', nargs='?', const=1, type=int, + default=None, help='output projected band structure or partial density of state(PDOS) to files, should be used with `-b` or `-d`.') parser.set_defaults(func=Show().show_cmdline) args = parser.parse_args() diff --git a/tools/plot-tools/abacus_plot/utils.py b/tools/plot-tools/abacus_plot/utils.py index 72be9e0b2a..8e31a87c58 100644 --- a/tools/plot-tools/abacus_plot/utils.py +++ b/tools/plot-tools/abacus_plot/utils.py @@ -9,7 +9,8 @@ import re import string from os import PathLike -from typing import List, Sequence, Union +from typing import List, Sequence, Union, Any, Dict +from collections import defaultdict import numpy as np @@ -21,6 +22,14 @@ def remove_empty(a: list) -> list: while [] in a: a.remove([]) +def handle_data(data): + data.remove('') + + def handle_elem(elem): + elist = elem.split(' ') + remove_empty(elist) # `list` will be modified in function + return elist + return list(map(handle_elem, data)) def skip_notes(line: str) -> str: """Delete comments lines with '#' or '//' @@ -257,3 +266,68 @@ def get_angular_momentum_name(l_index: int, m_index: int) -> str: """ return angular_momentum_name[l_index][m_index] + + +def parse_projected_data(orbitals, species: Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[Any, List[int]]]], keyname=''): + """Extract projected data from file + + Args: + species (Union[Sequence[Any], Dict[Any, List[int]], Dict[Any, Dict[str, List[int]]]], optional): list of atomic species(index or atom index) or dict of atomic species(index or atom index) and its angular momentum list. Defaults to []. + keyname (str): the keyword that extracts the projected data. Allowed values: 'index', 'atom_index', 'species' + """ + + if isinstance(species, (list, tuple)): + data = {} + elements = species + for elem in elements: + count = 0 + data_temp = np.zeros_like(orbitals[0]["data"], dtype=float) + for orb in orbitals: + if orb[keyname] == elem: + data_temp += orb["data"] + count += 1 + if count: + data[elem] = data_temp + + return data, len(elements) + + elif isinstance(species, dict): + data = defaultdict(dict) + elements = list(species.keys()) + l = list(species.values()) + totnum = 0 + for i, elem in enumerate(elements): + if isinstance(l[i], dict): + for ang, mag in l[i].items(): + l_count = 0 + l_index = int(ang) + l_data = {} + for m_index in mag: + m_count = 0 + data_temp = np.zeros_like( + orbitals[0]["data"], dtype=float) + for orb in orbitals: + if orb[keyname] == elem and orb["l"] == l_index and orb["m"] == m_index: + data_temp += orb["data"] + m_count += 1 + l_count += 1 + if m_count: + l_data[m_index] = data_temp + totnum += 1 + if l_count: + data[elem][l_index] = l_data + + elif isinstance(l[i], list): + for l_index in l[i]: + count = 0 + data_temp = np.zeros_like( + orbitals[0]["data"], dtype=float) + for orb in orbitals: + if orb[keyname] == elem and orb["l"] == l_index: + data_temp += orb["data"] + count += 1 + if count: + data[elem][l_index] = data_temp + totnum += 1 + + return data, totnum \ No newline at end of file diff --git a/tools/plot-tools/setup.py b/tools/plot-tools/setup.py index b76066b0b0..265dfd7d9d 100644 --- a/tools/plot-tools/setup.py +++ b/tools/plot-tools/setup.py @@ -10,7 +10,7 @@ if __name__ == "__main__": setup( name='abacus_plot', - version='1.1.0', + version='1.2.0', packages=find_packages(), description='Ploting tools for ABACUS', author='jiyuyang',