From dd7fee24d19a64fe4f08b30f38f68105b405d7b5 Mon Sep 17 00:00:00 2001 From: "Jamal I. Mustafa" Date: Tue, 23 Aug 2016 23:05:14 -0700 Subject: [PATCH] add the package --- setup.py | 28 ++++++ w90utils/__init__.py | 8 ++ w90utils/_amn.py | 47 ++++++++++ w90utils/_mmn.py | 65 ++++++++++++++ w90utils/_utils.py | 8 ++ w90utils/io/__init__.py | 74 +++++++++++++++ w90utils/io/_amn.py | 52 +++++++++++ w90utils/io/_bands.py | 34 +++++++ w90utils/io/_chk.py | 53 +++++++++++ w90utils/io/_eig.py | 84 +++++++++++++++++ w90utils/io/_hr.py | 36 ++++++++ w90utils/io/_mmn.py | 82 +++++++++++++++++ w90utils/io/_orbitals.py | 52 +++++++++++ w90utils/io/_unk.py | 35 ++++++++ w90utils/io/_utils.py | 120 +++++++++++++++++++++++++ w90utils/io/nnkp.py | 134 +++++++++++++++++++++++++++ w90utils/io/postw90.py | 117 ++++++++++++++++++++++++ w90utils/io/utils.py | 13 +++ w90utils/io/win.py | 190 +++++++++++++++++++++++++++++++++++++++ w90utils/io/wout.py | 50 +++++++++++ w90utils/sprd.py | 160 +++++++++++++++++++++++++++++++++ 21 files changed, 1442 insertions(+) create mode 100644 setup.py create mode 100644 w90utils/__init__.py create mode 100644 w90utils/_amn.py create mode 100644 w90utils/_mmn.py create mode 100644 w90utils/_utils.py create mode 100644 w90utils/io/__init__.py create mode 100644 w90utils/io/_amn.py create mode 100644 w90utils/io/_bands.py create mode 100644 w90utils/io/_chk.py create mode 100644 w90utils/io/_eig.py create mode 100644 w90utils/io/_hr.py create mode 100644 w90utils/io/_mmn.py create mode 100644 w90utils/io/_orbitals.py create mode 100644 w90utils/io/_unk.py create mode 100644 w90utils/io/_utils.py create mode 100644 w90utils/io/nnkp.py create mode 100644 w90utils/io/postw90.py create mode 100644 w90utils/io/utils.py create mode 100644 w90utils/io/win.py create mode 100644 w90utils/io/wout.py create mode 100644 w90utils/sprd.py diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8a4be1b --- /dev/null +++ b/setup.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import, division, print_function + +from setuptools import setup, find_packages + + +setup( + name='wannier90-utils', + version='0.1.0', + description='Wannier90 utility library', + author='Jamal I. Mustafa', + author_email='jimustafa@gmail.com', + license='BSD', + classifiers=[ + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: BSD License', + 'Natural Language :: English', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 2.7', + 'Topic :: Scientific/Engineering :: Physics', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + packages=find_packages(exclude=['docs', 'tests']), + install_requires=[ + 'numpy', + 'scipy', + ], + tests_require=['pytest'], +) diff --git a/w90utils/__init__.py b/w90utils/__init__.py new file mode 100644 index 0000000..fc5d2c1 --- /dev/null +++ b/w90utils/__init__.py @@ -0,0 +1,8 @@ +"""Wannier90 utility library""" +from __future__ import absolute_import, division, print_function + +from . import io +from . import sprd +from ._amn import expand_amn +from ._mmn import rotate_mmn +from ._utils import unitarize diff --git a/w90utils/_amn.py b/w90utils/_amn.py new file mode 100644 index 0000000..f538977 --- /dev/null +++ b/w90utils/_amn.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + + +def expand_amn(a, kpoints, idx, Rvectors, nproj_atom=None): + """ + Expand the projections matrix by translations of the orbitals + + Parameters + ---------- + a : ndarray, shape (nkpts, nbnds, nproj) + kpoints : ndarray, shape (nkpts, 3) + idx : ndarray + indices of translated orbitals + Rvectors: ndarray + translation vectors for the orbitals + nproj_atom: ndarray, optional + number of projections on each atom, with idx and Rvectors now describing + atoms instead of orbitals + + """ + assert len(Rvectors) == len(idx) + + if nproj_atom is not None: + assert len(nproj_atom) == len(idx) + idx_new = [] + Rvectors_new = [] + for iatom, i in enumerate(idx): + offset = np.sum(nproj_atom[:i]) + for j in range(nproj_atom[i]): + idx_new.append(offset+j) + Rvectors_new.append(Rvectors[iatom]) + + idx = idx_new + Rvectors = Rvectors_new + + nkpts, nbnds, nproj = a.shape + + a1 = np.zeros((nkpts, nbnds, len(idx)), dtype=complex) + + k_dot_R = np.einsum('ki,ri->kr', kpoints, Rvectors) + exp_factors = np.exp(-1j * 2*np.pi * k_dot_R) + + a1 = a[:, :, idx] * exp_factors[:, np.newaxis, :] + + return a1 diff --git a/w90utils/_mmn.py b/w90utils/_mmn.py new file mode 100644 index 0000000..ec2e4d0 --- /dev/null +++ b/w90utils/_mmn.py @@ -0,0 +1,65 @@ +from __future__ import division, print_function + +import numpy as np + + +def rotate_mmn(mmn, umn, kpb_kidx, window=None): + """ + Rotate the overlap matrices according to + :math:`U^{(\mathbf{k})\dagger}M^{(\mathbf{k},\mathbf{b})}U^{(\mathbf{k}+\mathbf{b})}` + + """ + (nkpts, nntot, nbnds, nbnds) = mmn.shape + nproj = umn[0].shape[1] + + mmn_rotated = np.empty((nkpts, nntot, nproj, nproj), dtype=complex) + + if window is not None: + for ikpt in range(nkpts): + for inn in range(nntot): + ikpb = kpb_kidx[ikpt][inn] + mmn_rotated[ikpt][inn] = ( + np.dot( + np.dot( + umn[ikpt].conj().T, mmn[ikpt][inn][window[ikpt]][:, window[ikpb]], + ), + umn[ikpb] + ) + ) + else: + for ikpt in range(nkpts): + for inn in range(nntot): + ikpb = kpb_kidx[ikpt][inn] + mmn_rotated[ikpt][inn] = np.dot(np.dot(umn[ikpt].conj().T, mmn[ikpt][inn]), umn[ikpb]) + + return mmn_rotated + + +# def change_gauge_k(m, u, setup_file): +# (nkpts, nntot, nbnds, nbnds) = m.shape +# nproj = u[0].shape[1] + +# m_rotated = np.zeros((nkpts, nntot, nproj, nbnds), dtype=complex) + +# for ikpt in range(nkpts): +# for inn in range(nntot): +# m_rotated[ikpt][inn] = np.dot(u[ikpt].conj().T, m[ikpt][inn]) + +# return m_rotated + + +# def change_gauge_kpb(m, u, setup_file, kpb_kidx=None): +# (nkpts, nntot, nbnds, nbnds) = m.shape +# nproj = u[0].shape[1] + +# if kpb_kidx is None: +# kpb_kidx = setup_file.kpb_kidx + +# m_rotated = np.zeros((nkpts, nntot, nbnds, nproj), dtype=complex) + +# for ikpt in range(nkpts): +# for inn in range(nntot): +# ikpb = kpb_kidx[ikpt][inn] +# m_rotated[ikpt][inn] = np.dot(m[ikpt][inn], u[ikpb]) + +# return m_rotated diff --git a/w90utils/_utils.py b/w90utils/_utils.py new file mode 100644 index 0000000..6177ca2 --- /dev/null +++ b/w90utils/_utils.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + + +def unitarize(a): + u, _, v = np.linalg.svd(a, full_matrices=False) + return np.einsum('...ik,...kj->...ij', u, v) diff --git a/w90utils/io/__init__.py b/w90utils/io/__init__.py new file mode 100644 index 0000000..e3f7c32 --- /dev/null +++ b/w90utils/io/__init__.py @@ -0,0 +1,74 @@ +"""Wannier90 I/O library""" +from __future__ import absolute_import, division, print_function +import collections + +from ._amn import * +from ._bands import * +from ._chk import * +from ._eig import * +from ._hr import * +from ._mmn import * +from ._unk import * +from . import nnkp +from . import postw90 +from . import utils +from . import win +from . import wout +from . import _utils + + +Wannier90Data = collections.namedtuple( + 'Wannier90Data', + [ + 'dlv', 'rlv', + 'amn', 'mmn', 'eig', + 'kpoints', 'kpb_kidx', 'kpb_g', + 'bv', 'bw', + 'length_unit', 'energy_unit' + ]) + + +def read_data(seedname='wannier', **kwargs): + """ + Read all Wannier90 input data files from the current directory. + + Parameters + ---------- + seedname : str, optional + seedname for the Wannier90 files, the default is "wannier" + + """ + dlv = kwargs.get('dlv', nnkp.read_dlv(seedname+'.nnkp', units='angstrom')) + rlv = kwargs.get('rlv', nnkp.read_rlv(seedname+'.nnkp', units='angstrom')) + try: + amn = kwargs['amn'] + except KeyError: + amn = read_amn(seedname+'.amn') + try: + mmn = kwargs['mmn'] + except KeyError: + mmn = read_mmn(seedname+'.mmn') + try: + eig = kwargs['eig'] + except KeyError: + eig = read_eig(seedname+'.eig') + kpoints = kwargs.get('kpoints', nnkp.read_kpoints(seedname+'.nnkp')) + kpb_kidx = kwargs.get('kpb_kidx', nnkp.read_nnkpts(seedname+'.nnkp')[0]) + kpb_g = kwargs.get('kpb_kidx', nnkp.read_nnkpts(seedname+'.nnkp')[1]) + bv = kwargs.get('bv', nnkp.read_bvectors(seedname+'.nnkp', units='angstrom')) + bw = kwargs.get('bw', _utils.bweights(bv)) + + return Wannier90Data( + dlv=dlv, + rlv=rlv, + amn=amn, + mmn=mmn, + eig=eig, + kpoints=kpoints, + kpb_kidx=kpb_kidx, + kpb_g=kpb_g, + bv=bv, + bw=bw, + length_unit='angstrom', + energy_unit='eV', + ) diff --git a/w90utils/io/_amn.py b/w90utils/io/_amn.py new file mode 100644 index 0000000..9297873 --- /dev/null +++ b/w90utils/io/_amn.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + + +__all__ = ['read_amn', 'write_amn'] + + +def read_amn(fname): + """ + Read AMN file. + + Parameters + ---------- + fname : str + + Returns + ------- + amn : ndarray, shape (nkpts, nbnds, nproj) + + """ + with open(fname, 'r') as f: + f.readline() # header + [nbnds, nkpts, nproj] = map(int, f.readline().split()) + data_str = f.read() + + raw_data = np.fromstring(data_str, sep='\n').reshape((nkpts*nbnds*nproj, 5)) + amn = raw_data[:, 3] + 1j*raw_data[:, 4] + amn = np.copy(np.transpose(amn.reshape((nbnds, nproj, nkpts), order='F'), axes=(2, 0, 1)), order='C') + + return amn + + +def write_amn(fname, amn, header='HEADER'): + r""" + Write :math:`A^{(\mathbf{k})}_{mn}` to AMN file. + + Parameters + ---------- + fname : str + amn : ndarray, shape (nkpts, nbnds, nproj) + header : str + + """ + (nkpts, nbnds, nproj) = amn.shape + indices = np.mgrid[:nbnds, :nproj, :nkpts].reshape((3, -1), order='F') + 1 + amn = np.transpose(amn, axes=(1, 2, 0)).flatten(order='F').view(float).reshape((-1, 2)) + data_out = np.column_stack((indices.transpose(), amn)) + with open(fname, 'w') as f: + print(header, file=f) + print('%13d%13d%13d' % (nbnds, nkpts, nproj), file=f) + np.savetxt(f, data_out, fmt='%5d%5d%5d%18.12f%18.12f') diff --git a/w90utils/io/_bands.py b/w90utils/io/_bands.py new file mode 100644 index 0000000..2749b69 --- /dev/null +++ b/w90utils/io/_bands.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + + +__all__ = ['read_kpoints', 'read_bands'] + + +def read_kpoints(fname): + raw_data = np.loadtxt(fname, skiprows=1) + kpoints = raw_data[:, (0, 1, 2)] + kweights = raw_data[:, 3] + + return kpoints + + +def read_bands(fname): + nkpts = None + with open(fname, 'r') as f: + for iln, line in enumerate(f): + if len(line.strip()) == 0: + nkpts = iln + break + + if nkpts is None: + raise Exception + + print(nkpts) + + raw_data = np.loadtxt(fname) + nbnds = len(raw_data) // nkpts + bands = raw_data[:, 1].reshape((nbnds, nkpts)).transpose() + + return bands diff --git a/w90utils/io/_chk.py b/w90utils/io/_chk.py new file mode 100644 index 0000000..0becae5 --- /dev/null +++ b/w90utils/io/_chk.py @@ -0,0 +1,53 @@ +from __future__ import absolute_import, division, print_function +import contextlib +import cStringIO as StringIO + +import numpy as np +from scipy.io import FortranFile + + +__all__ = ['CheckpointIO'] + + +class CheckpointIO(object): + def __init__(self, fname=None, auto_read=True): + if fname and auto_read: + self.from_file(fname) + + def from_file(self, fname): + with FortranFile(fname, 'r') as f: + self.header = ''.join(f.read_record('c')) + self.nbnds = f.read_ints()[0] + self.nbnds_excl = f.read_ints()[0] + self.bands_excl = f.read_ints() + self.dlv = f.read_reals().reshape((3, 3), order='F') + self.rlv = f.read_reals().reshape((3, 3), order='F') + self.nkpts = f.read_ints()[0] + self.grid_dims = f.read_ints() + self.kpoints = f.read_reals().reshape((-1, 3)) + self.nntot = f.read_ints()[0] + self.nwann = f.read_ints()[0] + self.chkpt = f.read_record('c') + self.disentanglement = bool(f.read_ints()[0]) + + if self.disentanglement: + self.omega_invariant = f.read_reals()[0] + self.windows = f.read_ints().reshape((self.nbnds, self.nkpts), order='F').transpose + f.read_ints() + self.umat_opt = np.transpose(f.read_reals().view(complex).reshape((self.nbnds, self.nwann, self.nkpts), order='F'), axes=(2, 0, 1)) + + self.umat = np.transpose(f.read_reals().view(complex).reshape((self.nwann, self.nwann, self.nkpts), order='F'), axes=(2, 0, 1)) + self.mmat = np.transpose(f.read_reals().view(complex).reshape((self.nwann, self.nwann, self.nntot, self.nkpts), order='F'), axes=(3, 2, 0, 1)) + self.wannier_centers = f.read_reals().reshape((-1, 3)) + self.wannier_spreads = f.read_reals() + + def __str__(self): + with contextlib.closing(StringIO.StringIO()) as sio: + print(self.header, file=sio) + print(self.nbnds, file=sio) + print(self.nbnds_excl, file=sio) + print(self.dlv, file=sio) + print(self.rlv, file=sio) + print(self.nkpts, file=sio) + s = sio.getvalue() + return s diff --git a/w90utils/io/_eig.py b/w90utils/io/_eig.py new file mode 100644 index 0000000..28b9716 --- /dev/null +++ b/w90utils/io/_eig.py @@ -0,0 +1,84 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + + +__all__ = ['read_eig', 'write_eig', 'read_hamiltonian'] + + +def read_eig(fname): + """ + Read EIG file. + + Parameters + ---------- + fname : str + + Returns + ------- + eig : ndarray, shape (nkpts, nbnds, nproj) + + """ + raw_data = np.loadtxt(fname) + + band_indices = raw_data[:, 0].astype(int) + kpoint_indices = raw_data[:, 1].astype(int) + + nbnds = np.max(band_indices) + nkpts = np.max(kpoint_indices) + + eig = raw_data[:, 2] + eig = eig.reshape((nkpts, nbnds)) + + return eig + + +def write_eig(fname, eig): + r""" + Write :math:`E_{n\mathbf{k}}` to EIG file. + + Parameters + ---------- + fname : eig + eig : ndarray, shape (nkpts, nbnds) + + + """ + nkpts, nbnds = eig.shape + indices = np.mgrid[:nbnds, :nkpts].reshape((2, nkpts*nbnds), order='F')+1 + + band_indices = indices[0] + kpoint_indices = indices[1] + + eig = eig.flatten() + + data = np.column_stack((band_indices, kpoint_indices, eig)) + + np.savetxt(fname, data, fmt='%5d%5d%18.12f') + + +def read_hamiltonian(fname): + """ + Read EIG file and return k-dependent Hamiltonian matrix. + + Parameters + ---------- + fname: str + + Returns + ------- + Hk: ndarray, shape (nkpts, nbnds, nbnds) + + """ + eig = read_eig(fname) + nkpts, nbnds = eig.shape + + Hk = np.zeros((nkpts, nbnds, nbnds)) + di = np.diag_indices(nbnds) + for ikpt in range(nkpts): + Hk[ikpt][di] = eig[ikpt] + + return Hk + + +read_hk = read_hamiltonian diff --git a/w90utils/io/_hr.py b/w90utils/io/_hr.py new file mode 100644 index 0000000..b3cc42f --- /dev/null +++ b/w90utils/io/_hr.py @@ -0,0 +1,36 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + + +__all__ = ['read_hr'] + + +def read_hr(fname): + with open(fname, 'r') as f: + contents = f.readlines() + + # header = contents[0] + nwann = int(contents[1].strip()) + nrpts = int(contents[2].strip()) + + nints_per_line = 15 # as indicated in Wannier90 User Guide Sec. 8.18 + # determine the number of lines the degeneracies span + ndegen_lines = nrpts // nints_per_line + if nrpts % 15 != 0: + ndegen_lines += 1 + + # read the degeneracy data + rdegen = np.fromstring(''.join(contents[3:3+ndegen_lines]), sep='\n') + + istart_hr = 3+ndegen_lines + raw_data = np.fromstring(''.join(contents[istart_hr:]), sep='\n').reshape((-1, 7)) + + hr = raw_data[:, 5] + 1j*raw_data[:, 6] + hr = hr.reshape((nrpts, nwann**2)).reshape((nrpts, nwann, nwann), order='F') + hr = np.copy(hr, order='C') + + Rvectors = np.copy(raw_data[:, :3].astype(int)[::nwann**2], order='C') + Rweights = 1 / rdegen + + return hr, Rvectors, Rweights diff --git a/w90utils/io/_mmn.py b/w90utils/io/_mmn.py new file mode 100644 index 0000000..b24b325 --- /dev/null +++ b/w90utils/io/_mmn.py @@ -0,0 +1,82 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + + +__all__ = ['read_mmn', 'write_mmn'] + + +def _process_mmn_file(fname): + with open(fname, 'r') as f: + contents = f.readlines() + + header = contents[0] + + [nbnds, nkpts, nntot] = np.fromstring(contents[1], sep=' ', dtype=int) + + nblks = nkpts * nntot + # a block consists of the header line and the following nbnds**2 lines + blk_len = nbnds**2 + 1 + # indices of the starting line of each block + blk_start_idx = range(2, len(contents), blk_len) + + kpb_kidx = np.zeros((nkpts, nntot), dtype=int) + kpb_g = np.zeros((nkpts, nntot, 3), dtype=int) + mmn = np.zeros((nkpts, nntot, nbnds, nbnds), dtype=complex) + for (iblk, istart) in enumerate(blk_start_idx): + # determine kpoint-index and nearest-neighbor index + # based on the index of the block + # -------------------------------------------------- + ikpt = iblk // nntot + if iblk % nntot == 0: + inn = 0 + else: + inn += 1 + # -------------------------------------------------- + block = contents[istart:(istart+blk_len)] + block_header = block[0] + kpb_kidx[ikpt][inn] = int(block_header.split()[1]) - 1 + kpb_g[ikpt][inn] = map(int, block_header.split()[2:]) + s = ''.join(block[1:]) + a = np.fromstring(s, sep='\n').view(complex) + mmn[ikpt, inn, :, :] = a.reshape((nbnds, nbnds), order='F') + + return mmn, kpb_kidx, kpb_g + + +def read_mmn(fname): + """ + Read MMN file + + Parameters + ---------- + fname : str + + Returns + ------- + mmn : ndarray, shape (nkpts, nntot, nbnds, nbnds) + + """ + return _process_mmn_file(fname)[0] + + +def write_mmn(fname, mmn, kpb_kidx, kpb_g): + """ + Write :math:`M^{(\mathbf{k},\mathbf{b})}_{mn}` to MMN file + + Parameters + ---------- + fname : str + mmn : ndarray, shape (nkpts, nntot, nbnds, nbnds) + + """ + nkpts = mmn.shape[0] + nntot = mmn.shape[1] + nbnds = mmn.shape[2] + with open(fname, 'w') as f: + print('DUMMY HEADER', file=f) + print('%12d%12d%12d' % (nbnds, nkpts, nntot), file=f) + for ikpt in range(nkpts): + for inn in range(nntot): + print('%5d%5d%5d%5d%5d' % ((ikpt+1, kpb_kidx[ikpt][inn]+1) + tuple(kpb_g[ikpt][inn])), file=f) + np.savetxt(f, mmn[ikpt][inn].flatten(order='F').view(float).reshape(-1, 2), fmt='%18.12f%18.12f') diff --git a/w90utils/io/_orbitals.py b/w90utils/io/_orbitals.py new file mode 100644 index 0000000..8be7ffc --- /dev/null +++ b/w90utils/io/_orbitals.py @@ -0,0 +1,52 @@ +"""Predefined trial orbitals""" + +atomic_orbitals = {} +atomic_orbitals[0] = {} +atomic_orbitals[0][1] = 's' +atomic_orbitals[1] = {} +atomic_orbitals[1][1] = 'pz' +atomic_orbitals[1][2] = 'px' +atomic_orbitals[1][3] = 'py' +atomic_orbitals[2] = {} +atomic_orbitals[2][1] = 'dz2' +atomic_orbitals[2][2] = 'dxz' +atomic_orbitals[2][3] = 'dyz' +atomic_orbitals[2][4] = 'dx2-y2' +atomic_orbitals[2][5] = 'dxy' +atomic_orbitals[3] = {} +atomic_orbitals[3][1] = 'fz3' +atomic_orbitals[3][2] = 'fxz2' +atomic_orbitals[3][3] = 'fyz2' +atomic_orbitals[3][4] = 'fz(x2-y2)' +atomic_orbitals[3][5] = 'fxyz' +atomic_orbitals[3][6] = 'fx(x2-3y2)' +atomic_orbitals[3][7] = 'fy(3x2-y2)' + +hybrid_orbitals = {} +hybrid_orbitals[-1] = {} +hybrid_orbitals[-1][1] = 'sp-1' +hybrid_orbitals[-1][2] = 'sp-2' +hybrid_orbitals[-2] = {} +hybrid_orbitals[-2][1] = 'sp2-1' +hybrid_orbitals[-2][2] = 'sp2-2' +hybrid_orbitals[-2][3] = 'sp2-3' +hybrid_orbitals[-3] = {} +hybrid_orbitals[-3][1] = 'sp3-1' +hybrid_orbitals[-3][2] = 'sp3-2' +hybrid_orbitals[-3][3] = 'sp3-3' +hybrid_orbitals[-3][4] = 'sp3-4' +hybrid_orbitals[-4] = {} +hybrid_orbitals[-4][1] = 'sp3d-1' +hybrid_orbitals[-4][2] = 'sp3d-2' +hybrid_orbitals[-4][3] = 'sp3d-3' +hybrid_orbitals[-4][4] = 'sp3d-4' +hybrid_orbitals[-4][5] = 'sp3d-5' +hybrid_orbitals[-5] = {} +hybrid_orbitals[-5][1] = 'sp3d2-1' +hybrid_orbitals[-5][2] = 'sp3d2-2' +hybrid_orbitals[-5][3] = 'sp3d2-3' +hybrid_orbitals[-5][4] = 'sp3d2-4' +hybrid_orbitals[-5][5] = 'sp3d2-5' +hybrid_orbitals[-5][6] = 'sp3d2-6' + +orbitals = dict(atomic_orbitals, **hybrid_orbitals) diff --git a/w90utils/io/_unk.py b/w90utils/io/_unk.py new file mode 100644 index 0000000..c190870 --- /dev/null +++ b/w90utils/io/_unk.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +from scipy.io import FortranFile + + +# __all__ = ['read_unk', 'write_unk'] + + +# def read_unk(fname, gspace=False): +# with FortranFile(fname) as f: +# [ngx, ngy, ngz, ikpt, nbnds] = f.read_ints() + +# ngtot = ngx * ngy * ngz +# unk_r = np.zeros((nbnds, ngtot), dtype=complex) +# for ibnd in range(nbnds): +# unk_r[ibnd, :] = f.read_reals().view(complex) + +# unk_r = unk_r.reshape((nbnds, ngx, ngy, ngz), order='F') / np.sqrt(ngtot) + +# if gspace: +# unk_g = np.fft.fftn(unk_r, axes=(1, 2, 3)).reshape((nbnds, ngtot), order='F') / np.sqrt(ngtot) +# return unk_g +# else: +# return unk_r + + +# def write_unk(fname, unk, ikpt): +# (nbnds, ngx, ngy, ngz) = unk.shape + +# with FortranFile(fname, 'w') as f: +# f.write_record(np.array([ngx, ngy, ngz, ikpt+1, nbnds], dtype=np.int32)) + +# for ibnd in range(nbnds): +# f.write_record(unk[ibnd].ravel(order='F').view(float)) diff --git a/w90utils/io/_utils.py b/w90utils/io/_utils.py new file mode 100644 index 0000000..52ee057 --- /dev/null +++ b/w90utils/io/_utils.py @@ -0,0 +1,120 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +from scipy.constants import codata + + +# energy conversions +Ha2eV = codata.value('Hartree energy in eV') +eV2Ha = 1/Ha2eV +Ha2meV = Ha2eV * 1000 +Ry2eV = codata.value('Rydberg constant times hc in eV') +Ha2Ry = 2.0 +Ry2Ha = 0.5 +Ha2K = codata.value('hartree-kelvin relationship') + +# length conversions +bohr2angstrom = codata.value('Bohr radius') * 1e10 +angstrom2bohr = 1 / bohr2angstrom + +_conversion_factors = { + 'angstrom': { + 'angstrom': 1, + 'bohr': angstrom2bohr, + }, + 'bohr': { + 'bohr': 1, + 'angstrom': bohr2angstrom, + }, + 'Ha': { + 'Ha': 1, + 'eV': Ha2eV, + 'Ry': Ha2Ry, + }, + 'Ry': { + 'Ha': 1./2, + 'Ry': 1, + 'eV': Ry2eV + } +} + + +def convert_units(a, from_units, to_units, inverse=False, copy=True): + """ + Convert units using predefined unit names and conversion factors + + Parameters + ---------- + a : array_like + from_units : str + from unit + to_units : str + to unit + copy : bool, optional + flag to copy the input array, or convert the units in-place (copy=False) + + """ + a = np.asarray(a) + if copy: + a = np.copy(a) + + if inverse: + a *= 1./_conversion_factors[from_units][to_units] + else: + a *= _conversion_factors[from_units][to_units] + + return a + + +def crystal2cartesian(coords, lattice_vectors): + """ + Convert vectors expressed in crystal coordinates to Cartesian coordinates + + Parameters + ---------- + coords: ndarray, shape (..., npts, 3) + lattice_vectors: ndarray, shape (3, 3) + + Returns + ------- + ndarray + + """ + + return np.dot(coords, lattice_vectors) + + +def cartesian2crystal(coords, lattice_vectors): + """ + Convert vectors expressed in Cartesian coordinates to crystal coordinates + + Parameters + ---------- + coords: ndarray, shape (..., npts, 3) + lattice_vectors: ndarray, shape (3, 3) + + Returns + ------- + ndarray + + """ + + return np.dot(coords, np.linalg.inv(lattice_vectors)) + + +def bweights(bvectors): + distances = np.sqrt(np.sum(bvectors[0]**2, axis=1)) + shells = np.unique(np.round(distances, decimals=6)) + bvec_shells = [bvectors[0][abs(distances - shell) < 5e-7] for shell in shells] + + amat = np.zeros((6, len(shells))) + for (ish, bvecs) in enumerate(bvec_shells): + bmat = np.array([np.einsum('i,j->ij', x, y) for x, y in zip(bvecs, bvecs)]) + bmat = np.sum(bmat, axis=0) + amat[:, ish] = bmat[np.triu_indices(3)] + + [U, s, V] = np.linalg.svd(amat, full_matrices=False) + bweights_tmp = reduce(np.dot, [np.transpose(V), np.diag(1.0/s), np.transpose(U), np.eye(3)[np.triu_indices(3)]]) + bweights = np.concatenate(tuple([np.ones(len(bvec_shells[ish]))*bweights_tmp[ish] for ish in range(len(shells))])) + + return bweights diff --git a/w90utils/io/nnkp.py b/w90utils/io/nnkp.py new file mode 100644 index 0000000..2dfcc23 --- /dev/null +++ b/w90utils/io/nnkp.py @@ -0,0 +1,134 @@ +"""Wannier90 I/O routines pertaining to NNKP files""" +from __future__ import absolute_import, division, print_function +import re + +import numpy as np + +from ._orbitals import orbitals +from . import _utils + + +def read_dlv(fname, units='bohr'): + pattern = re.compile(r'(?:begin\s+real_lattice)(.+)(?:end\s+real_lattice)', re.IGNORECASE | re.DOTALL) + with open(fname, 'r') as f: + match = pattern.search(f.read()) + if match is None: + raise Exception + + dlv = np.fromstring(match.group(1), sep='\n').reshape((3, 3)) + + dlv = _utils.convert_units(dlv, 'angstrom', units) + + return dlv + + +def read_rlv(fname, units='bohr'): + pattern = re.compile(r'(?:begin\s+recip_lattice)(.+)(?:end\s+recip_lattice)', re.IGNORECASE | re.DOTALL) + with open(fname, 'r') as f: + match = pattern.search(f.read()) + if match is None: + raise Exception + + rlv = np.fromstring(match.group(1), sep='\n').reshape((3, 3)) + + rlv = _utils.convert_units(rlv, 'angstrom', units, inverse=True) + + return rlv + + +def read_kpoints(fname, units='crystal'): + pattern = re.compile(r'(?:begin\s+kpoints)(?:\s+(?P[0-9]+)\s+)(?P.+)(?:end\s+kpoints)', re.IGNORECASE | re.DOTALL) + with open(fname, 'r') as f: + match = pattern.search(f.read()) + if match is None: + raise Exception + + nkpts = int(match.group('nkpts')) + kpoints = np.fromstring(match.group('kpoints'), sep='\n').reshape((nkpts, 3)) + + if units == 'crystal': + pass + elif units == 'angstrom' or units == 'bohr': + rlv = read_rlv(fname, units) + kpoints = np.dot(kpoints, rlv) + + return kpoints + + +def read_projections(fname): + pattern = re.compile(r'(?:begin\s+(?Pspinor_)?projections)(?:\s+(?P[0-9]+)\s+)(?P.+)(?:end\s+(?P=spinor)?projections)', re.IGNORECASE | re.DOTALL) + with open(fname, 'r') as f: + match = pattern.search(f.read()) + if match is None: + raise Exception + + nproj = int(match.group('nproj')) + spinors = match.group('spinor') is not None + raw_data = np.fromstring(match.group('projections'), sep='\n') + + if not spinors: + raw_data = np.reshape(raw_data, (nproj, 13)) + else: + raw_data = np.reshape(raw_data, (nproj, 17)) + + # create list of projections + # each projection is a dictionary + projections = [] + for iproj in range(len(raw_data)): + proj = {} + proj['center'] = raw_data[iproj][:3] + proj['l'] = l = int(raw_data[iproj][3]) + proj['mr'] = mr = int(raw_data[iproj][4]) + proj['r'] = int(raw_data[iproj][5]) + proj['z-axis'] = raw_data[iproj][6:9] + proj['x-axis'] = raw_data[iproj][9:12] + proj['zona'] = raw_data[iproj][12] + proj['spin'] = int(raw_data[iproj][13]) if spinors else None + proj['spin-axis'] = raw_data[iproj][14:] if spinors else None + proj['orbital'] = orbitals[l][mr] + + projections.append(proj) + + return projections + + +def read_nnkpts(fname): + pattern = re.compile(r'(?:begin\s+nnkpts)(?:\s+(?P[0-9]+)\s+)(?P.+)(?:end\s+nnkpts)', re.IGNORECASE | re.DOTALL) + with open(fname, 'r') as f: + match = pattern.search(f.read()) + if match is None: + raise Exception + + nntot = int(match.group('nntot')) + raw_data = np.fromstring(match.group('nnkpts'), sep='\n', dtype=int).reshape((-1, 5)) + + kpb_kidx = raw_data[:, 1].reshape((-1, nntot)) - 1 + kpb_g = raw_data[:, 2:].reshape((-1, nntot, 3)) + + return kpb_kidx, kpb_g + + +def read_bvectors(fname, units='angstrom'): + rlv = read_rlv(fname, units=units) + kpoints = read_kpoints(fname) + kpb_kidx, kpb_g = read_nnkpts(fname) + + kpb = kpoints[kpb_kidx] + + bvectors = kpb + kpb_g - kpoints[:, np.newaxis, :] + bvectors = np.einsum('kbi,ij->kbj', bvectors, rlv) + + return bvectors + + +def read_excluded_bands(fname): + pattern = re.compile(r'(?:begin\s+exclude_bands)(.+)(?:end\s+exclude_bands)', re.IGNORECASE | re.DOTALL) + with open(fname, 'r') as f: + match = pattern.search(f.read()) + if match is None: + raise Exception + + bnd_idx = np.fromstring(match.group(1), sep='\n')[1:] + bnd_idx -= 1 + + return bnd_idx diff --git a/w90utils/io/postw90.py b/w90utils/io/postw90.py new file mode 100644 index 0000000..b749baf --- /dev/null +++ b/w90utils/io/postw90.py @@ -0,0 +1,117 @@ +from __future__ import absolute_import, division, print_function +import sys + +import numpy as np + + +def read_kpoints(fname): + raw_data = np.loadtxt(fname, skiprows=3) + kpoints = raw_data[:, (1, 2, 3)] + + return kpoints + + +def write_kpoints(fname, kpoints): + with open(fname, 'w') as f: + print_kpoints(kpoints, file=f) + + +def print_kpoints(kpoints, header='', units='crystal', file=sys.stdout): + print(header, file=file) + print('%s' % units, file=file) + print(len(kpoints), file=file) + for (ik, kpt) in enumerate(kpoints): + print('%5d%18.12f%18.12f%18.12f' % tuple([ik+1] + list(kpt)), file=file) + + +def read_bands_kpoints(fname): + """ + Read k-points from the geninterp dat file + + Parameters + ---------- + fname : str + + Returns + ------- + kpoints : ndarray, shape (nkpts, 3) + array of kpoints using for geninterp, in of units :math:`\frac{1}{\text{\AA}}` + + """ + raw_data = np.loadtxt(fname) + + nkpts = int(np.max(raw_data[:, 0])) + nbnds = len(raw_data) // nkpts + + kpoints = raw_data[np.arange(0, len(raw_data), nbnds), 1:4] + + return kpoints + + +def read_bands(fname): + raw_data = np.loadtxt(fname) + nkpts = np.max(raw_data[:, 0]) + bands = raw_data[:, 4] + bands = bands.reshape((nkpts, -1)) + + return bands + + +def read_band_velocities(fname): + raw_data = np.loadtxt(fname) + + if not raw_data.shape[1] > 5: + return None + + nkpts = raw_data[-1, 0] + vnk = raw_data[:, 5:] + vnk = vnk.reshape((nkpts, -1, 3)) + + return vnk + + +read_vnk = read_band_velocities + + +def _read_boltzwann_data(fname): + raw_data = np.loadtxt(fname) + + if raw_data.ndim == 1: + nT = 1 + nmu = 1 + raw_data = raw_data.reshape((1, -1)) + mu, T = raw_data[:, (0, 1)].T + else: + mu, T = raw_data[:, (0, 1)].T + + dT = T[1] - T[0] + nT = int(np.rint((np.max(T) - np.min(T))/dT))+1 + T = T[:nT] + + mu = mu.reshape((-1, nT))[:, :1] + nmu = len(mu) + + raw_data = raw_data.reshape((nmu, nT, 8)) + + data = { + 'xx': raw_data[:, :, 2], + 'xy': raw_data[:, :, 3], + 'yy': raw_data[:, :, 4], + 'xz': raw_data[:, :, 5], + 'yz': raw_data[:, :, 6], + 'zz': raw_data[:, :, 7], + } + + return data, mu, T + + +def read_elcond(fname): + return _read_boltzwann_data(fname) + + +def read_dos(fname): + data = np.loadtxt(fname) + e = data[:, 0] + dos = data[:, 1] + + return e, dos diff --git a/w90utils/io/utils.py b/w90utils/io/utils.py new file mode 100644 index 0000000..b3670be --- /dev/null +++ b/w90utils/io/utils.py @@ -0,0 +1,13 @@ +from __future__ import absolute_import, division, print_function +import re + + +# def unk_get_ikpt_ispn(fname): +# p = re.compile(r'UNK(?P\d+)[.](?P(1|2))') +# match = p.match(fname) +# if match is None: +# return None +# else: +# ikpt = int(match.group('ikpt')) +# ispn = int(match.group('ispn')) +# return ikpt-1, ispn-1 diff --git a/w90utils/io/win.py b/w90utils/io/win.py new file mode 100644 index 0000000..d088838 --- /dev/null +++ b/w90utils/io/win.py @@ -0,0 +1,190 @@ +"""Wannier90 I/O routines pertaining to WIN files""" +from __future__ import absolute_import, division, print_function +import sys +import re + +import numpy as np + +from . import _utils + + +unit_cell_regex = re.compile( + r'BEGIN\s+UNIT_CELL_CART\s+' + r'(?PBOHR|ANG)?' + r'(?P.+)' + r'END\s+UNIT_CELL_CART\s+', + re.VERBOSE | re.IGNORECASE | re.DOTALL + ) + +atoms_regex = re.compile( + r'BEGIN\s+ATOMS_(?P(FRAC)|(CART))\s+' + r'(?PBOHR|ANG)?' + r'(?P.+)' + r'END\s+ATOMS_(?P=suffix)\s+', + re.VERBOSE | re.IGNORECASE | re.DOTALL + ) + +kpoints_regex = re.compile( + r'BEGIN\s+KPOINTS\s+' + r'(?P.+)' + r'END\s+KPOINTS\s+', + re.VERBOSE | re.IGNORECASE | re.DOTALL + ) + +kgrid_regex = re.compile( + r'MP_GRID\s+(?P\d+)\s+(?P\d+)\s+(?P\d+)\s+', + re.VERBOSE | re.IGNORECASE | re.DOTALL + ) + + +def read_dlv(fname, units='bohr'): + """ + Read direct lattice vectors from WIN file. + + Parameters + ---------- + fname : str + Wannier90 WIN file + units : str, {'bohr', 'angstrom'} + units of returned lattice vectors + + Returns + ------- + dlv : ndarray, shape (3, 3) + direct lattice vectors + + """ + with open(fname, 'r') as f: + match = unit_cell_regex.search(f.read()) + if match is None: + raise Exception + + dlv = np.fromstring(match.group('dlv').strip(), sep='\n').reshape((3, 3)) + + if match.group('units') is not None: + units_win = {'ANG': 'angstrom', 'BOHR': 'bohr'}[match.group('units').upper()] + else: + units_win = 'angstrom' + + if units == units_win: + pass + elif units in ['bohr', 'angstrom'] and units_win in ['bohr', 'angstrom']: + dlv = _utils.convert_units(dlv, units_win, units) + else: + raise Exception + + return dlv + + +def read_atoms(fname, units='crystal'): + with open(fname, 'r') as f: + match = atoms_regex.search(f.read()) + if match is None: + raise Exception + + symbols = [] + taus = [] + for line in match.group('atoms').strip().splitlines(): + symbols.append(line.split()[0]) + taus.append(np.array(map(float, line.split()[1:]))) + + if match.group('suffix').upper() == 'FRAC': + units_win = 'crystal' + else: + if match.group('units') is not None: + units_win = {'ANG': 'angstrom', 'BOHR': 'bohr'}[match.group('units').upper()] + else: + units_win = 'angstrom' + + taus = np.asarray(taus) + + if units == units_win: + pass + elif units == 'crystal' and units_win in ['bohr', 'angstrom']: + dlv = read_dlv(fname, units=units_win) + taus = _utils.cartesian2crystal(taus, dlv) + elif units in ['bohr', 'angstrom'] and units_win == 'crystal': + dlv = read_dlv(fname, units=units) + taus = _utils.crystal2cartesian(taus, dlv) + elif units in ['bohr', 'angstrom'] and units_win in ['bohr', 'angstrom']: + taus = _utils.convert_units(taus, units_win, units) + else: + raise Exception + + basis = list(zip(symbols, taus)) + + return basis + + +def read_kgrid(fname): + with open(fname, 'r') as f: + match = kgrid_regex.search(f.read()) + if match is None: + raise Exception + + kgrid = (int(match.group('nk1')), int(match.group('nk2')), int(match.group('nk3'))) + + return kgrid + + +def read_kpoints(fname): + with open(fname, 'r') as f: + match = kpoints_regex.search(f.read()) + if match is None: + raise Exception + + kpoints = np.fromstring(match.group('kpoints').strip(), sep='\n').reshape((-1, 3)) + + return kpoints + + +def print_unit_cell(dlv, units='bohr', file=sys.stdout): + units = units.upper() + + print('BEGIN UNIT_CELL_CART', file=file) + # + if units == 'BOHR' or 'ANG': + print(units, file=file) + else: + raise ValueError('units must be "bohr" or "ang"') + # + np.savetxt(file, dlv, fmt='%18.12f') + # + print('END UNIT_CELL_CART', file=file) + print('', file=file) + + +def print_atoms(atoms, units='crystal', file=sys.stdout): + units = units.upper() + + if units == 'CRYSTAL': + block_label = 'ATOMS_FRAC' + print('BEGIN ATOMS_FRAC', file=file) + elif units == 'BOHR' or units == 'ANG': + block_label = 'ATOMS_CART' + print('BEGIN ATOMS_CART', file=file) + print(units, file=file) + else: + raise ValueError('units must be "CRYSTAL", "BOHR", or "ANG"') + # + for symbol, tau in atoms: + print('%-5s ' % symbol, end='', file=file) + np.savetxt(file, np.asarray(tau).reshape((1, 3)), fmt='%18.12f') + # + print('END %s' % block_label, file=file) + print('', file=file) + + +def print_kgrid(kgrid, file=sys.stdout): + print('MP_GRID %3d %3d %3d' % tuple(kgrid), file=file) + + +def print_kpoints(kpoints, mp_grid=None, file=sys.stdout): + if mp_grid is not None: + nkpts = np.prod(mp_grid) + if nkpts != len(kpoints): + raise Exception + print('MP_GRID %3d %3d %3d' % tuple(mp_grid), file=file) + print('BEGIN KPOINTS', file=file) + np.savetxt(file, kpoints, fmt='%18.12f') + print('END KPOINTS', file=file) diff --git a/w90utils/io/wout.py b/w90utils/io/wout.py new file mode 100644 index 0000000..db4282e --- /dev/null +++ b/w90utils/io/wout.py @@ -0,0 +1,50 @@ +"""Wannier90 I/O routines pertaining to WOUT files""" +from __future__ import absolute_import, division, print_function + +import numpy as np + + +def read_centers_xyz(fname): + with open(fname, 'r') as f: + contents = f.readlines() + + centers = [] + for line in contents[2:]: + symbol = line.split()[0] + tau = np.array(map(float, line.split()[1:])) + + centers.append((symbol, tau)) + + return centers + + +def read_conv(fname): + with open(fname, 'r') as f: + lines = f.readlines() + + conv_data = [] + data_lines = [] + for line in lines: + if line.strip().endswith('CONV'): + data_lines.append(line) + + conv_data = np.array([map(float, line.split()[:4]) for line in data_lines[3:]]) + conv_data = dict(zip(['iter', 'delta', 'gradient', 'spread', 'time'], conv_data.T)) + conv_data['iter'] = conv_data['iter'].astype(int) + + return conv_data + + +def read_sprd(fname): + with open(fname, 'r') as f: + lines = f.readlines() + + sprd_data = {'D': [], 'OD': [], 'TOT': []} + for line in lines: + if line.strip().endswith('SPRD'): + data = line.split() + sprd_data['D'].append(float(data[1])) + sprd_data['OD'].append(float(data[3])) + sprd_data['TOT'].append(float(data[5])) + + return sprd_data diff --git a/w90utils/sprd.py b/w90utils/sprd.py new file mode 100644 index 0000000..7a7a589 --- /dev/null +++ b/w90utils/sprd.py @@ -0,0 +1,160 @@ +"""Functions for computing components of the spread""" +from __future__ import absolute_import, division, print_function + +import numpy as np + + +def wannier_centers(m, bvectors, bweights): + (nkpts, nntot, nwann) = m.shape[:-1] + + mii = m.reshape((nkpts*nntot, nwann, nwann)).diagonal(offset=0, axis1=1, axis2=2) + + bvectors = np.reshape(bvectors, (-1, 3)) + bweights = np.tile(bweights, nkpts) + + bwv = bweights[:, np.newaxis] * bvectors + + # Eq. 31 + rv = np.zeros((nwann, 3)) + for i in range(nwann): + rv[i] = -1 * np.sum(bwv * np.imag(np.log(mii[:, i]))[:, np.newaxis], axis=0) + rv /= nkpts + + return rv + + +def omega_d(m, bvectors, bweights, idx=None): # Eq. 36 + """ + Compute the diagonal contribution to the spread functional + + Parameters + ---------- + m: ndarray, shape (nkpts, nntot, nbnds, nbnds) + the overlap matrix + bvectors: ndarray, shape (nkpts, nntot, 3) + bweights: ndarray, shape (nntot,) + + """ + (nkpts, nntot, nwann) = m.shape[:-1] + + rv = wannier_centers(m, bvectors, bweights) + + mii = np.copy(m.reshape((nkpts*nntot, nwann, nwann)).diagonal(offset=0, axis1=1, axis2=2)) + + bvectors = np.reshape(bvectors, (-1, 3)) + bweights = np.tile(bweights, nkpts) + + bvrv = np.einsum('bi,ri->br', bvectors, rv) + + if idx is not None: + sprd_d = np.sum(bweights[:, np.newaxis] * (-1 * np.imag(np.log(mii)) - bvrv)**2, axis=0) / nkpts + sprd_d = sprd_d[idx] + else: + sprd_d = np.sum(bweights[:, np.newaxis] * (-1 * np.imag(np.log(mii)) - bvrv)**2) / nkpts + + return sprd_d + + +def omega_od(Mmn, bweights): # Eq. 35 + """ + Compute the off-diagonal contribution to the spread functional + + Parameters + ---------- + Mmn: ndarray, shape (nkpts, nntot, nbnds, nbnds) + the overlap matrix + bweights: ndarray, shape (nntot,) + + """ + (nkpts, nntot, nbnds) = Mmn.shape[:-1] + + m = Mmn.reshape((nkpts*nntot, nbnds, nbnds)) + mii = m.diagonal(offset=0, axis1=1, axis2=2) + + bweights = np.tile(bweights, nkpts) + + sprd_od = np.sum(bweights[:, np.newaxis, np.newaxis] * np.abs(m)**2) + sprd_od -= np.sum(bweights[:, np.newaxis] * np.abs(mii)**2) + sprd_od /= nkpts + + return sprd_od + + +def omega_dod(Mmn, bvectors, bweights): + """ + Compute the spread functional + + Parameters + ---------- + Mmn: ndarray, shape (nkpts, nntot, nbnds, nbnds) + the overlap matrix + bvectors: ndarray, shape (nkpts, nntot, 3) + bweights: ndarray, shape (nntot,) + + """ + sprd_d = omega_d(Mmn, bvectors, bweights) + sprd_od = omega_od(Mmn, bweights) + + return sprd_d + sprd_od + + +def omega_iod(m, bweights, idx=None): # Eq. 43 + """ + Compute the sum of the invariant and off-diagonal contribution to the spread + functional + + Parameters + ---------- + m: ndarray, shape (nkpts, nntot, nbnds, nbnds) + the overlap matrix + bweights: ndarray, shape (nntot,) + + """ + (nkpts, nntot, nwann) = m.shape[:-1] + + mii = m.reshape((nkpts*nntot, nwann, nwann)).diagonal(offset=0, axis1=1, axis2=2) + + bweights = np.tile(bweights, nkpts) + + if idx is not None: + sprd_iod = np.sum(bweights[:, np.newaxis] * (1 - np.abs(mii)**2), axis=0) / nkpts + sprd_iod = sprd_iod[idx] + else: + sprd_iod = np.sum(bweights[:, np.newaxis] * (1 - np.abs(mii)**2)) / nkpts + + return sprd_iod + + +def omega_i(Mmn, bweights): + """ + Compute the invariant and off-diagonal contribution to the spread functional + + Parameters + ---------- + Mmn: ndarray, shape (nkpts, nntot, nbnds, nbnds) + the overlap matrix + bweights: ndarray, shape (nntot,) + + """ + sprd_od = omega_od(Mmn, bweights) + sprd_iod = omega_iod(Mmn, bweights) + + return sprd_iod - sprd_od + + +def omega(Mmn, bvectors, bweights): + """ + Compute the spread functional + + Parameters + ---------- + Mmn: ndarray, shape (nkpts, nntot, nbnds, nbnds) + the overlap matrix + bvectors: ndarray, shape (nkpts, nntot, 3) + bweights: ndarray, shape (nntot,) + + """ + sprd_d = omega_d(Mmn, bvectors, bweights) + sprd_iod = omega_iod(Mmn, bweights) + + return sprd_d + sprd_iod