Skip to content

Commit

Permalink
Refactor: remove init_wfc&mem_saver&out_wfc_pw&out_wfc_r of w…
Browse files Browse the repository at this point in the history
…avefunc in abacus (#5557)

* remove wavefunc. init_wfc mem_saver out_wfc_pw out_wfc_r in abacus

* replace WFInit by PSIInit
  • Loading branch information
haozhihan authored Nov 22, 2024
1 parent 54b044b commit e0202e0
Show file tree
Hide file tree
Showing 13 changed files with 135 additions and 145 deletions.
8 changes: 0 additions & 8 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,6 @@ ESolver_KS<T, Device>::ESolver_KS()
///----------------------------------------------------------
p_chgmix = new Charge_Mixing();
p_chgmix->set_rhopw(this->pw_rho, this->pw_rhod);

///----------------------------------------------------------
/// wavefunc
///----------------------------------------------------------
this->wf.init_wfc = PARAM.inp.init_wfc;
this->wf.mem_saver = PARAM.inp.mem_saver;
this->wf.out_wfc_pw = PARAM.inp.out_wfc_pw;
this->wf.out_wfc_r = PARAM.inp.out_wfc_r;
}

//------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
// mohan move it outside 2011-01-13
// first need to calculate the weight according to
// electrons number.
if (istep == 0 && this->wf.init_wfc == "file")
if (istep == 0 && PARAM.inp.init_wfc == "file")
{
if (iter == 1)
{
Expand Down
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(const Input_para& inp, UnitCell&

void ESolver_KS_LCAO_TDDFT::hamilt2density_single(const int istep, const int iter, const double ethr)
{
if (wf.init_wfc == "file")
if (PARAM.inp.init_wfc == "file")
{
if (istep >= 1)
{
Expand Down Expand Up @@ -256,7 +256,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
const int nlocal = PARAM.globalv.nlocal;

// store wfc and Hk laststep
if (istep >= (wf.init_wfc == "file" ? 0 : 1) && this->conv_esolver)
if (istep >= (PARAM.inp.init_wfc == "file" ? 0 : 1) && this->conv_esolver)
{
if (this->psi_laststep == nullptr)
{
Expand Down Expand Up @@ -311,7 +311,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
}

// calculate energy density matrix for tddft
if (istep >= (wf.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
if (istep >= (PARAM.inp.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
{
elecstate::cal_edm_tddft(this->pv, this->pelec, this->kv, this->p_hamilt);
}
Expand Down
18 changes: 9 additions & 9 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCel
}

//! 7) prepare some parameters for electronic wave functions initilization
this->p_wf_init = new psi::WFInit<T, Device>(PARAM.inp.init_wfc,
PARAM.inp.ks_solver,
PARAM.inp.basis_type,
PARAM.inp.psi_initializer,
&this->wf,
this->pw_wfc);
this->p_wf_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
PARAM.inp.ks_solver,
PARAM.inp.basis_type,
PARAM.inp.psi_initializer,
&this->wf,
this->pw_wfc);
this->p_wf_init->prepare_init(&(this->sf),
&ucell,
1,
Expand Down Expand Up @@ -547,7 +547,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(const int istep, int& iter)
}

// 4) Print out electronic wavefunctions
if (this->wf.out_wfc_pw == 1 || this->wf.out_wfc_pw == 2)
if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2)
{
std::stringstream ssw;
ssw << PARAM.globalv.global_out_dir << "WAVEFUNC";
Expand All @@ -573,7 +573,7 @@ void ESolver_KS_PW<T, Device>::after_scf(const int istep)
ESolver_KS<T, Device>::after_scf(istep);

// 3) output wavefunctions
if (this->wf.out_wfc_pw == 1 || this->wf.out_wfc_pw == 2)
if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2)
{
std::stringstream ssw;
ssw << PARAM.globalv.global_out_dir << "WAVEFUNC";
Expand Down Expand Up @@ -821,7 +821,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners()
}

//! 6) Print out electronic wave functions in real space
if (this->wf.out_wfc_r == 1) // Peize Lin add 2021.11.21
if (PARAM.inp.out_wfc_r == 1) // Peize Lin add 2021.11.21
{
ModuleIO::write_psi_r_1(this->psi[0], this->pw_wfc, "wfc_realspace", true, this->kv);
}
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi = nullptr;

// psi_initializer controller
psi::WFInit<T, Device>* p_wf_init = nullptr;
psi::PSIInit<T, Device>* p_wf_init = nullptr;

Device* ctx = {};

Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitC
// 2) run "before_all_runners" in ESolver_KS
ESolver_KS_PW<T, Device>::before_all_runners(inp, ucell);

// 9) initialize the stochastic wave functions
// 3) initialize the stochastic wave functions
this->stowf.init(&this->kv, this->pw_wfc->npwk_max);
if (inp.nbands_sto != 0)
{
Expand All @@ -75,7 +75,7 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitC
}
this->stowf.sync_chi0();

// 10) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}>
// 4) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}>
size_t size = stowf.chi0->size();
this->stowf.shchi
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->wf.npwx, this->kv.ngk.data());
Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/lcao_others.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
this->pv,
this->GG,
PARAM.inp.out_wfc_pw,
this->wf.out_wfc_r,
PARAM.inp.out_wfc_r,
this->kv,
PARAM.inp.nelec,
PARAM.inp.nbands_istate,
Expand All @@ -351,7 +351,7 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
this->pv,
this->GK,
PARAM.inp.out_wfc_pw,
this->wf.out_wfc_r,
PARAM.inp.out_wfc_r,
this->kv,
PARAM.inp.nelec,
PARAM.inp.nbands_istate,
Expand Down
40 changes: 20 additions & 20 deletions source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
const int nks2 = nks;

psi::Psi<std::complex<double>>* psi_out = nullptr;
if (PARAM.inp.calculation == "nscf" && this->mem_saver == 1)
if (PARAM.inp.calculation == "nscf" && PARAM.inp.mem_saver == 1)
{
// initial psi rather than evc
psi_out = new psi::Psi<std::complex<double>>(1, PARAM.inp.nbands, npwx * PARAM.globalv.npol, ngk);
Expand Down Expand Up @@ -140,11 +140,11 @@ void wavefunc::wfcinit(psi::Psi<std::complex<double>>* psi_in, ModulePW::PW_Basi

int wavefunc::get_starting_nw() const
{
if (init_wfc == "file")
if (PARAM.inp.init_wfc == "file")
{
return PARAM.inp.nbands;
}
else if (init_wfc.substr(0, 6) == "atomic")
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
{
if (GlobalC::ucell.natomwfc >= PARAM.inp.nbands)
{
Expand All @@ -164,7 +164,7 @@ int wavefunc::get_starting_nw() const
}
return std::max(GlobalC::ucell.natomwfc, PARAM.inp.nbands);
}
else if (init_wfc == "random")
else if (PARAM.inp.init_wfc == "random")
{
if (PARAM.inp.test_wf)
{
Expand Down Expand Up @@ -196,7 +196,7 @@ void diago_PAO_in_pw_k2(const int& ik,
const int nbands = wvf.get_nbands();
const int current_nbasis = wfc_basis->npwk[ik];

if (p_wf->init_wfc == "file")
if (PARAM.inp.init_wfc == "file")
{
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
std::stringstream filename;
Expand Down Expand Up @@ -263,7 +263,7 @@ void diago_PAO_in_pw_k2(const int& ik,
}
*/

if (p_wf->init_wfc == "random" || (p_wf->init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
{
p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis);

Expand All @@ -280,7 +280,7 @@ void diago_PAO_in_pw_k2(const int& ik,
}
}
}
else if (p_wf->init_wfc.substr(0, 6) == "atomic")
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
{
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc
if (PARAM.inp.test_wf) {
Expand All @@ -296,7 +296,7 @@ void diago_PAO_in_pw_k2(const int& ik,
PARAM.globalv.nqx,
PARAM.globalv.dq);

if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
{
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
}
Expand Down Expand Up @@ -355,7 +355,7 @@ void diago_PAO_in_pw_k2(const int& ik,
const int nbands = wvf.get_nbands();
const int current_nbasis = wfc_basis->npwk[ik];

if (p_wf->init_wfc == "file")
if (PARAM.inp.init_wfc == "file")
{
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
std::stringstream filename;
Expand Down Expand Up @@ -420,7 +420,7 @@ void diago_PAO_in_pw_k2(const int& ik,
assert(starting_nw > 0);
std::vector<double> etatom(starting_nw, 0.0);

if (p_wf->init_wfc == "random" || (p_wf->init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
{
p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis);
if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02
Expand All @@ -436,7 +436,7 @@ void diago_PAO_in_pw_k2(const int& ik,
}
}
}
else if (p_wf->init_wfc.substr(0, 6) == "atomic")
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
{
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc
if (PARAM.inp.test_wf)
Expand All @@ -453,7 +453,7 @@ void diago_PAO_in_pw_k2(const int& ik,
PARAM.globalv.nqx,
PARAM.globalv.dq);

if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
{
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
}
Expand Down Expand Up @@ -534,7 +534,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
int starting_nw = nbands;

ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
if (p_wf->init_wfc == "file")
if (PARAM.inp.init_wfc == "file")
{
std::stringstream filename;
int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot);
Expand All @@ -550,7 +550,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
if (PARAM.inp.test_wf)
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);

if (p_wf->init_wfc.substr(0, 6) == "atomic")
if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
{
p_wf->atomic_wfc(ik,
current_nbasis,
Expand All @@ -560,7 +560,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
GlobalC::ppcell.tab_at,
PARAM.globalv.nqx,
PARAM.globalv.dq);
if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
{
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
}
Expand All @@ -571,7 +571,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
//====================================================
p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis);
}
else if (p_wf->init_wfc == "random")
else if (PARAM.inp.init_wfc == "random")
{
p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis);
}
Expand Down Expand Up @@ -638,7 +638,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
int starting_nw = nbands;

ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
if (p_wf->init_wfc == "file")
if (PARAM.inp.init_wfc == "file")
{
std::stringstream filename;
int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot);
Expand All @@ -653,7 +653,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc
if (PARAM.inp.test_wf)
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
if (p_wf->init_wfc.substr(0, 6) == "atomic")
if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
{
p_wf->atomic_wfc(ik,
current_nbasis,
Expand All @@ -663,7 +663,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
GlobalC::ppcell.tab_at,
PARAM.globalv.nqx,
PARAM.globalv.dq);
if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
{
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
}
Expand All @@ -674,7 +674,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
//====================================================
p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis);
}
else if (p_wf->init_wfc == "random")
else if (PARAM.inp.init_wfc == "random")
{
p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis);
}
Expand Down
56 changes: 25 additions & 31 deletions source/module_hamilt_pw/hamilt_pwdft/wavefunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,50 +10,44 @@

class wavefunc : public WF_atomic
{
public:

public:
wavefunc();
~wavefunc();

// allocate memory
psi::Psi<std::complex<double>>* allocate(const int nkstot, const int nks, const int* ngk, const int npwx);

int out_wfc_pw = 0; //qianrui modify 2020-10-19
int out_wfc_r = 0; // Peize Lin add 2021.11.21
int nkstot = 0; // total number of k-points for all pools

// init_wfc : "random",or "atomic" or "file"
std::string init_wfc;
int nkstot = 0; // total number of k-points for all pools
int mem_saver = 0; // 1: save evc when doing nscf calculation.
void wfcinit(psi::Psi<std::complex<double>>* psi_in, ModulePW::PW_Basis_K* wfc_basis);
int get_starting_nw(void)const;

void init_after_vc(const int nks); //LiuXh 20180515
};
int get_starting_nw(void) const;

void init_after_vc(const int nks); // LiuXh 20180515
};

namespace hamilt
{

void diago_PAO_in_pw_k2(const int &ik,
psi::Psi<std::complex<float>> &wvf,
ModulePW::PW_Basis_K *wfc_basis,
wavefunc *p_wf,
hamilt::Hamilt<std::complex<float>> *phm_in = nullptr);
void diago_PAO_in_pw_k2(const int &ik,
psi::Psi<std::complex<double>> &wvf,
ModulePW::PW_Basis_K *wfc_basis,
wavefunc *p_wf,
hamilt::Hamilt<std::complex<double>> *phm_in = nullptr);
void diago_PAO_in_pw_k2(const int &ik, ModuleBase::ComplexMatrix &wvf, wavefunc *p_wf);
void diago_PAO_in_pw_k2(const int& ik,
psi::Psi<std::complex<float>>& wvf,
ModulePW::PW_Basis_K* wfc_basis,
wavefunc* p_wf,
hamilt::Hamilt<std::complex<float>>* phm_in = nullptr);
void diago_PAO_in_pw_k2(const int& ik,
psi::Psi<std::complex<double>>& wvf,
ModulePW::PW_Basis_K* wfc_basis,
wavefunc* p_wf,
hamilt::Hamilt<std::complex<double>>* phm_in = nullptr);
void diago_PAO_in_pw_k2(const int& ik, ModuleBase::ComplexMatrix& wvf, wavefunc* p_wf);

template <typename FPTYPE, typename Device>
void diago_PAO_in_pw_k2(const Device *ctx,
const int &ik,
psi::Psi<std::complex<FPTYPE>, Device> &wvf,
ModulePW::PW_Basis_K *wfc_basis,
wavefunc *p_wf,
hamilt::Hamilt<std::complex<FPTYPE>, Device> *phm_in = nullptr);
}

#endif //wavefunc
void diago_PAO_in_pw_k2(const Device* ctx,
const int& ik,
psi::Psi<std::complex<FPTYPE>, Device>& wvf,
ModulePW::PW_Basis_K* wfc_basis,
wavefunc* p_wf,
hamilt::Hamilt<std::complex<FPTYPE>, Device>* phm_in = nullptr);
} // namespace hamilt

#endif // wavefunc
Loading

0 comments on commit e0202e0

Please sign in to comment.