Skip to content

Commit

Permalink
Refactor: move "hPsi-call-act" procedure from OperatorPW into basic…
Browse files Browse the repository at this point in the history
… `Operator` and redesign `act()` interface (#2912)

* move act() interface to basic operator

* fix an illegal call of get_ngk()

* move the act-based hPsi into basic Operator

* an developer-friendly interface of act()
  • Loading branch information
maki49 authored Sep 10, 2023
1 parent 7fc58ce commit eae28f7
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 174 deletions.
40 changes: 36 additions & 4 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "module_hamilt_general/operator.h"
#include "module_base/timer.h"

using namespace hamilt;

Expand Down Expand Up @@ -34,10 +35,42 @@ Operator<FPTYPE, Device>::~Operator()
}

template<typename FPTYPE, typename Device>
typename Operator<FPTYPE, Device>::hpsi_info Operator<FPTYPE, Device>::hPsi(hpsi_info&) const
typename Operator<FPTYPE, Device>::hpsi_info Operator<FPTYPE, Device>::hPsi(hpsi_info& input) const
{
ModuleBase::WARNING_QUIT("Operator::hPsi", "hPsi error!");
return hpsi_info(nullptr, 0, nullptr);
ModuleBase::timer::tick("Operator", "hPsi");
using syncmem_op = psi::memory::synchronize_memory_op<FPTYPE, Device, Device>;
auto psi_input = std::get<0>(input);
std::tuple<const FPTYPE*, int> psi_info = psi_input->to_range(std::get<1>(input));
int nbands = std::get<1>(psi_info);

FPTYPE* tmhpsi = this->get_hpsi(input);
const FPTYPE* tmpsi_in = std::get<0>(psi_info);
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
if (tmpsi_in == nullptr)
{
ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!");
}

this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik));
Operator* node((Operator*)this->next_op);
while (node != nullptr)
{
node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik));
node = (Operator*)(node->next_op);
}

ModuleBase::timer::tick("Operator", "hPsi");

//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
FPTYPE* hpsi_pointer = std::get<2>(input);
if (this->in_place)
{
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
delete this->hpsi;
this->hpsi = new psi::Psi<FPTYPE, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
}
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
}

template<typename FPTYPE, typename Device>
Expand Down Expand Up @@ -118,7 +151,6 @@ FPTYPE* Operator<FPTYPE, Device>::get_hpsi(const hpsi_info& info) const
return hpsi_pointer;
}


namespace hamilt {
template class Operator<float, psi::DEVICE_CPU>;
template class Operator<std::complex<float>, psi::DEVICE_CPU>;
Expand Down
19 changes: 17 additions & 2 deletions source/module_hamilt_general/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,30 @@ class Operator
//this is the core function for Operator
// do H|psi> from input |psi> ,

// output of hpsi would be first member of the returned tuple
/// as default, different operators donate hPsi independently
/// run this->act function for the first operator and run all act() for other nodes in chain table
/// if this procedure is not suitable for your operator, just override this function.
/// output of hpsi would be first member of the returned tuple
typedef std::tuple<const psi::Psi<FPTYPE, Device>*, const psi::Range, FPTYPE*> hpsi_info;
virtual hpsi_info hPsi(hpsi_info& input)const;

virtual void init(const int ik_in);

virtual void add(Operator* next);

virtual int get_ik() const {return this->ik;}
virtual int get_ik() const { return this->ik; }

///do operation : |hpsi_choosed> = V|psi_choosed>
///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi
virtual void act(const int nbands,
const int nbasis,
const int npol,
const FPTYPE* tmpsi_in,
FPTYPE* tmhpsi,
const int ngk_ik = 0)const {};

/// an developer-friendly interface for act() function
virtual psi::Psi<FPTYPE> act(const psi::Psi<FPTYPE>& psi_in) const { return psi_in; };

Operator* next_op = nullptr;

Expand Down
25 changes: 13 additions & 12 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,27 @@ Ekinetic<OperatorPW<FPTYPE, Device>>::~Ekinetic() {}

template<typename FPTYPE, typename Device>
void Ekinetic<OperatorPW<FPTYPE, Device>>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const
const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik)const
{
ModuleBase::timer::tick("Operator", "EkineticPW");
const int npw = psi_in->get_ngk(this->ik);
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
ModuleBase::timer::tick("Operator", "EkineticPW");
int max_npw = nbasis / npol;

const FPTYPE *gk2_ik = &(this->gk2[this->ik * this->gk2_col]);
// denghui added 20221019
ekinetic_op()(this->ctx, n_npwx, npw, this->max_npw, tpiba2, gk2_ik, tmhpsi, tmpsi_in);
// for (int ib = 0; ib < n_npwx; ++ib)
ekinetic_op()(this->ctx, nbands, ngk_ik, max_npw, tpiba2, gk2_ik, tmhpsi, tmpsi_in);
// for (int ib = 0; ib < nbands; ++ib)
// {
// for (int ig = 0; ig < npw; ++ig)
// for (int ig = 0; ig < ngk_ik; ++ig)
// {
// tmhpsi[ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
// }
// tmhpsi += this->max_npw;
// tmpsi_in += this->max_npw;
// tmhpsi += max_npw;
// tmpsi_in += max_npw;
// }
ModuleBase::timer::tick("Operator", "EkineticPW");
}
Expand Down
15 changes: 6 additions & 9 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ class Ekinetic<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

virtual ~Ekinetic();

virtual void act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const override;
virtual void act(const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik = 0)const override;

// denghuilu added for copy construct at 20221105
int get_gk2_row() const {return this->gk2_row;}
Expand All @@ -49,10 +50,6 @@ class Ekinetic<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

private:

mutable int max_npw = 0;

mutable int npol = 0;

FPTYPE tpiba2 = 0.0;
const FPTYPE* gk2 = nullptr;
int gk2_row = 0;
Expand Down
25 changes: 12 additions & 13 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ Meta<OperatorPW<FPTYPE, Device>>::~Meta()

template<typename FPTYPE, typename Device>
void Meta<OperatorPW<FPTYPE, Device>>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi
)const
const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik)const
{
if (XC_Functional::get_func_type() != 3)
{
Expand All @@ -52,29 +53,27 @@ void Meta<OperatorPW<FPTYPE, Device>>::act(

ModuleBase::timer::tick("Operator", "MetaPW");

const int npw = psi_in->get_ngk(this->ik);
const int current_spin = this->isk[this->ik];
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
int max_npw = nbasis / npol;
//npol == 2 case has not been considered
this->npol = psi_in->npol;

for (int ib = 0; ib < n_npwx; ++ib)
for (int ib = 0; ib < nbands; ++ib)
{
for (int j = 0; j < 3; j++)
{
meta_op()(this->ctx, this->ik, j, npw, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), tmpsi_in, this->porter);
meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), tmpsi_in, this->porter);
wfcpw->recip_to_real(this->ctx, this->porter, this->porter, this->ik);

if(this->vk_col != 0) {
vector_mul_vector_op()(this->ctx, this->vk_col, this->porter, this->porter, this->vk + current_spin * this->vk_col);
}

wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik);
meta_op()(this->ctx, this->ik, j, npw, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), this->porter, tmhpsi, true);
meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), this->porter, tmhpsi, true);

} // x,y,z directions
tmhpsi += this->max_npw;
tmpsi_in += this->max_npw;
tmhpsi += max_npw;
tmpsi_in += max_npw;
}
ModuleBase::timer::tick("Operator", "MetaPW");
}
Expand Down
10 changes: 6 additions & 4 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class Meta<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

virtual ~Meta();

virtual void act(const psi::Psi<std::complex<FPTYPE>, Device>* psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi) const override;
virtual void act(const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk = 0)const override;

// denghui added for copy constructor at 20221105
FPTYPE get_tpiba() const
Expand Down
35 changes: 18 additions & 17 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,30 +206,31 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
}

template<typename FPTYPE, typename Device>
void Nonlocal<OperatorPW<FPTYPE, Device>>::act
(
const psi::Psi<std::complex<FPTYPE>, Device>* psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const
void Nonlocal<OperatorPW<FPTYPE, Device>>::act(
const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik)const
{
ModuleBase::timer::tick("Operator", "NonlocalPW");
this->npw = psi_in->get_ngk(this->ik);
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
this->npol = psi_in->npol;
this->npw = ngk_ik;
this->max_npw = nbasis / npol;
this->npol = npol;

if (this->ppcell->nkb > 0)
{
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// qianrui optimize 2021-3-31
int nkb = this->ppcell->nkb;
if (this->nkb_m < n_npwx * nkb) {
resmem_complex_op()(this->ctx, this->becp, n_npwx * nkb, "Nonlocal<PW>::becp");
if (this->nkb_m < nbands * nkb) {
resmem_complex_op()(this->ctx, this->becp, nbands * nkb, "Nonlocal<PW>::becp");
}
// ModuleBase::ComplexMatrix becp(n_npwx, nkb, false);
// ModuleBase::ComplexMatrix becp(nbands, nkb, false);
char transa = 'C';
char transb = 'N';
if (n_npwx == 1)
if (nbands == 1)
{
int inc = 1;
// denghui replace 2022-10-20
Expand All @@ -250,7 +251,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::act
}
else
{
int npm = n_npwx;
int npm = nbands;
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// denghui replace 2022-10-20
gemm_op()(
Expand All @@ -264,16 +265,16 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::act
this->vkb,
this->ppcell->vkb.nc,
tmpsi_in,
this->max_npw,
max_npw,
&this->zero,
this->becp,
nkb
);
}

Parallel_Reduce::reduce_complex_double_pool(becp, nkb * n_npwx);
Parallel_Reduce::reduce_complex_double_pool(becp, nkb * nbands);

this->add_nonlocal_pp(tmhpsi, becp, n_npwx);
this->add_nonlocal_pp(tmhpsi, becp, nbands);
}
ModuleBase::timer::tick("Operator", "NonlocalPW");
}
Expand Down
12 changes: 6 additions & 6 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class Nonlocal<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

virtual void init(const int ik_in)override;

virtual void act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi
)const override;
virtual void act(const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk = 0)const override;

const int *get_isk() const {return this->isk;}
const pseudopot_cell_vnl *get_ppcell() const {return this->ppcell;}
Expand Down
48 changes: 0 additions & 48 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,54 +7,6 @@ using namespace hamilt;
template<typename FPTYPE, typename Device>
OperatorPW<FPTYPE, Device>::~OperatorPW(){};

template<typename FPTYPE, typename Device>
typename OperatorPW<FPTYPE, Device>::hpsi_info OperatorPW<FPTYPE, Device>::hPsi(
hpsi_info& input) const
{
ModuleBase::timer::tick("OperatorPW", "hPsi");
auto psi_input = std::get<0>(input);
std::tuple<const std::complex<FPTYPE>*, int> psi_info = psi_input->to_range(std::get<1>(input));
int n_npwx = std::get<1>(psi_info);

std::complex<FPTYPE> *tmhpsi = this->get_hpsi(input);
const std::complex<FPTYPE> *tmpsi_in = std::get<0>(psi_info);
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
if(tmpsi_in == nullptr)
{
ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!");
}

this->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
OperatorPW* node((OperatorPW*)this->next_op);
while(node != nullptr)
{
node->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
node = (OperatorPW*)(node->next_op);
}

ModuleBase::timer::tick("OperatorPW", "hPsi");

//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
std::complex<FPTYPE>* hpsi_pointer = std::get<2>(input);
if(this->in_place)
{
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
syncmem_complex_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
delete this->hpsi;
this->hpsi = new psi::Psi<std::complex<FPTYPE>, Device>(hpsi_pointer, *psi_input, 1, n_npwx/psi_input->npol);
}
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/psi_input->npol), hpsi_pointer);
}

template<typename FPTYPE, typename Device>
void OperatorPW<FPTYPE, Device>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi) const
{
}

namespace hamilt {
template class OperatorPW<float, psi::DEVICE_CPU>;
template class OperatorPW<double, psi::DEVICE_CPU>;
Expand Down
Loading

0 comments on commit eae28f7

Please sign in to comment.