from __future__ import annotations
from functools import cached_property
from typing import TYPE_CHECKING, Callable, Generator, Generic, TypeVar
import numpy as np
from ase.io.ulm import Writer
from ase.units import Bohr, Ha
from gpaw.gpu import as_np, synchronize
from gpaw.gpu.mpi import CuPyMPI
from gpaw.mpi import MPIComm, serial_comm
from gpaw.new import zips
from gpaw.new.brillouin import IBZ
from gpaw.new.c import GPU_AWARE_MPI
from gpaw.new.potential import Potential
from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
from gpaw.new.wave_functions import WaveFunctions
from gpaw.typing import Array1D, Array2D, Self
from gpaw.utilities import pack_density
if TYPE_CHECKING:
from gpaw.new.density import Density
WFT = TypeVar('WFT', bound=WaveFunctions)
[docs]
class IBZWaveFunctions(Generic[WFT]):
def __init__(self,
ibz: IBZ,
*,
nelectrons: float,
ncomponents: int,
wfs_qs: list[list[WFT]],
kpt_comm: MPIComm = serial_comm,
kpt_band_comm: MPIComm = serial_comm,
comm: MPIComm = serial_comm):
"""Collection of wave function objects for k-points in the IBZ."""
self.ibz = ibz
self.kpt_comm = kpt_comm
self.kpt_band_comm = kpt_band_comm
self.comm = comm
self.nelectrons = nelectrons
self.ncomponents = ncomponents
self.collinear = (ncomponents != 4)
self.spin_degeneracy = ncomponents % 2 + 1
self.nspins = ncomponents % 3
self.rank_k = ibz.ranks(kpt_comm)
self.wfs_qs = wfs_qs
self.q_k = {} # IBZ-index to local index
for wfs in self:
self.q_k[wfs.k] = wfs.q
self.band_comm = wfs.band_comm
self.domain_comm = wfs.domain_comm
self.dtype = wfs.dtype
self.nbands = wfs.nbands
self.fermi_levels: Array1D | None = None # hartree
self.xp = self.wfs_qs[0][0].xp
if self.xp is not np:
if not GPU_AWARE_MPI:
self.kpt_comm = CuPyMPI(self.kpt_comm) # type: ignore
self.move_wave_functions: Callable[..., None] = lambda *args: None
self.read_from_file_init_wfs_dm = False
[docs]
@classmethod
def create(cls,
*,
ibz: IBZ,
nelectrons: float,
ncomponents: int,
create_wfs_func,
kpt_comm: MPIComm = serial_comm,
kpt_band_comm: MPIComm = serial_comm,
comm: MPIComm = serial_comm,
) -> Self:
"""Collection of wave function objects for k-points in the IBZ."""
rank_k = ibz.ranks(kpt_comm)
mask_k = (rank_k == kpt_comm.rank)
k_q = np.arange(len(ibz))[mask_k]
nspins = ncomponents % 3
wfs_qs: list[list[WFT]] = []
for q, k in enumerate(k_q):
wfs_s = []
for spin in range(nspins):
wfs = create_wfs_func(spin, q, k,
ibz.kpt_kc[k], ibz.weight_k[k])
wfs_s.append(wfs)
wfs_qs.append(wfs_s)
return cls(ibz,
nelectrons=nelectrons,
ncomponents=ncomponents,
wfs_qs=wfs_qs,
kpt_comm=kpt_comm,
kpt_band_comm=kpt_band_comm,
comm=comm)
@cached_property
def mode(self):
wfs = self.wfs_qs[0][0]
if isinstance(wfs, PWFDWaveFunctions):
if hasattr(wfs.psit_nX.desc, 'ecut'):
return 'pw'
return 'fd'
return 'lcao'
[docs]
def has_wave_functions(self):
return True
[docs]
def get_max_shape(self, global_shape: bool = False) -> tuple[int, ...]:
"""Find the largest wave function array shape.
For a PW-calculation, this shape could depend on k-point.
"""
if global_shape:
shape = np.array(max(wfs.array_shape(global_shape=True)
for wfs in self))
self.kpt_comm.max(shape)
return tuple(shape)
return max(wfs.array_shape() for wfs in self)
@property
def fermi_level(self) -> float:
fl = self.fermi_levels
assert fl is not None and len(fl) == 1
return fl[0]
def __str__(self):
shape = self.get_max_shape(global_shape=True)
wfs = self.wfs_qs[0][0]
nbytes = (len(self.ibz) *
self.nbands *
len(self.wfs_qs[0]) *
wfs.bytes_per_band)
ncores = (self.kpt_comm.size *
self.domain_comm.size *
self.band_comm.size)
return (f'{self.ibz.symmetries}\n'
f'{self.ibz}\n'
f'{wfs._short_string(shape)}\n'
f'spin-components: {self.ncomponents}'
' # (' +
('' if self.collinear else 'non-') + 'collinear spins)\n'
f'bands: {self.nbands}\n'
f'valence electrons: {self.nelectrons}\n'
f'spin-degeneracy: {self.spin_degeneracy}\n'
f'dtype: {self.dtype}\n\n'
'memory:\n'
f' storage: {"CPU" if self.xp is np else "GPU"}\n'
f' wave functions: {nbytes:_} # bytes '
f' ({nbytes // ncores:_} per core)\n\n'
'parallelization:\n'
f' kpt: {self.kpt_comm.size}\n'
f' domain: {self.domain_comm.size}\n'
f' band: {self.band_comm.size}\n')
def __iter__(self) -> Generator[WFT, None, None]:
for wfs_s in self.wfs_qs:
yield from wfs_s
[docs]
def move(self, relpos_ac, atomdist):
self.ibz.symmetries.check_positions(relpos_ac)
self.make_sure_wfs_are_read_from_gpw_file()
for wfs in self:
wfs.move(relpos_ac, atomdist, self.move_wave_functions)
[docs]
def orthonormalize(self, work_array_nX: np.ndarray = None):
for wfs in self:
wfs.orthonormalize(work_array_nX)
[docs]
def calculate_occs(self,
occ_calc,
fix_fermi_level=False) -> tuple[float, float, float]:
degeneracy = self.spin_degeneracy
# u index is q and s combined
occ_un, fermi_levels, e_entropy = occ_calc.calculate(
nelectrons=self.nelectrons / degeneracy,
eigenvalues=[wfs.eig_n * Ha for wfs in self],
weights=[wfs.weight for wfs in self],
fermi_levels_guess=(None
if self.fermi_levels is None else
self.fermi_levels * Ha),
fix_fermi_level=fix_fermi_level)
if not fix_fermi_level:
self.fermi_levels = np.array(fermi_levels) / Ha
else:
assert self.fermi_levels is not None
for occ_n, wfs in zips(occ_un, self):
wfs._occ_n = occ_n
e_entropy *= degeneracy / Ha
e_band = 0.0
for wfs in self:
e_band += wfs.occ_n @ wfs.eig_n * wfs.weight * degeneracy
e_band = self.kpt_comm.sum_scalar(float(e_band)) # XXX CPU float?
return e_band, e_entropy, e_entropy * occ_calc.extrapolate_factor
[docs]
def add_to_density(self, nt_sR, D_asii) -> None:
"""Compute density and add to ``nt_sR`` and ``D_asii``."""
for wfs in self:
wfs.add_to_density(nt_sR, D_asii)
if self.xp is not np:
synchronize()
# This should be done in a more efficient way!!!
# Also: where do we want the density?
self.kpt_comm.sum(nt_sR.data)
self.kpt_comm.sum(D_asii.data)
self.band_comm.sum(nt_sR.data)
self.band_comm.sum(D_asii.data)
[docs]
def normalize_density(self, density: Density) -> None:
pass # overwritten in LCAOIBZWaveFunctions class
[docs]
def add_to_ked(self, taut_sR) -> None:
for wfs in self:
wfs.add_to_ked(taut_sR)
if self.xp is not np:
synchronize()
self.kpt_comm.sum(taut_sR.data)
self.band_comm.sum(taut_sR.data)
[docs]
def get_all_electron_wave_function(self,
band,
kpt=0,
spin=0,
grid_spacing=0.05,
skip_paw_correction=False):
wfs = self.get_wfs(kpt=kpt, spin=spin, n1=band, n2=band + 1)
if wfs is None:
return None
assert isinstance(wfs, PWFDWaveFunctions)
psit_X = wfs.psit_nX[0].to_pbc_grid()
grid = psit_X.desc.uniform_grid_with_grid_spacing(grid_spacing)
psi_r = psit_X.interpolate(grid=grid)
if not skip_paw_correction:
dphi_aj = wfs.setups.partial_wave_corrections()
dphi_air = grid.atom_centered_functions(dphi_aj, wfs.relpos_ac)
dphi_air.add_to(psi_r, wfs.P_ani[:, 0])
return psi_r
[docs]
def get_wfs(self,
*,
kpt: int = 0,
spin: int = 0,
n1=0,
n2=0):
rank = self.rank_k[kpt]
if rank == self.kpt_comm.rank:
wfs = self.wfs_qs[self.q_k[kpt]][spin]
wfs2 = wfs.collect(n1, n2)
if rank == 0:
return wfs2
if wfs2 is not None:
wfs2.send(0, self.kpt_comm)
return
if self.comm.rank == 0:
return self.wfs_qs[0][0].receive(rank, self.kpt_comm)
return None
[docs]
def get_eigs_and_occs(self, k=0, s=0):
if self.domain_comm.rank == 0 and self.band_comm.rank == 0:
rank = self.rank_k[k]
if rank == self.kpt_comm.rank:
wfs = self.wfs_qs[self.q_k[k]][s]
if rank == 0:
return wfs._eig_n, wfs._occ_n
self.kpt_comm.send(wfs._eig_n, 0)
self.kpt_comm.send(wfs._occ_n, 0)
elif self.kpt_comm.rank == 0:
eig_n = np.empty(self.nbands)
occ_n = np.empty(self.nbands)
self.kpt_comm.receive(eig_n, rank)
self.kpt_comm.receive(occ_n, rank)
return eig_n, occ_n
return np.zeros(0), np.zeros(0)
[docs]
def get_all_eigs_and_occs(self, broadcast=False):
nkpts = len(self.ibz)
mynbands = self.nbands if self.comm.rank == 0 or broadcast else 0
eig_skn = np.empty((self.nspins, nkpts, mynbands))
occ_skn = np.empty((self.nspins, nkpts, mynbands))
for k in range(nkpts):
for s in range(self.nspins):
eig_n, occ_n = self.get_eigs_and_occs(k, s)
if self.comm.rank == 0:
eig_skn[s, k, :] = eig_n
occ_skn[s, k, :] = occ_n
if broadcast:
self.comm.broadcast(eig_skn, 0)
self.comm.broadcast(occ_skn, 0)
return eig_skn, occ_skn
[docs]
def forces(self, potential: Potential) -> Array2D:
self.make_sure_wfs_are_read_from_gpw_file()
F_av = self.xp.zeros((len(potential.dH_asii), 3))
for wfs in self:
wfs.force_contribution(potential, F_av)
if self.xp is not np:
synchronize()
self.kpt_band_comm.sum(F_av)
return F_av
[docs]
def write(self, writer: Writer, flags) -> None:
"""Write fermi-level(s), eigenvalues, occupation numbers, ...
... k-points, symmetry information, projections and possibly
also the wave functions.
"""
eig_skn, occ_skn = self.get_all_eigs_and_occs()
if not self.collinear:
eig_skn = eig_skn[0]
occ_skn = occ_skn[0]
assert self.fermi_levels is not None
writer.write(fermi_levels=self.fermi_levels * Ha,
eigenvalues=eig_skn * Ha,
occupations=occ_skn)
ibz = self.ibz
writer.child('kpts').write(
atommap=ibz.symmetries.atommap_sa,
bz2ibz=ibz.bz2ibz_K,
bzkpts=ibz.bz.kpt_Kc,
ibzkpts=ibz.kpt_kc,
rotations=ibz.symmetries.rotation_scc,
translations=ibz.symmetries.translation_sc,
weights=ibz.weight_k)
nproj = self.wfs_qs[0][0].P_ani.layout.size
spin_k_shape: tuple[int, ...]
proj_shape: tuple[int, ...]
if self.collinear:
spin_k_shape = (self.ncomponents, len(ibz))
proj_shape = (self.nbands, nproj)
else:
spin_k_shape = (len(ibz),)
proj_shape = (self.nbands, 2, nproj)
if flags.include_projections:
proj_dtype = flags.storage_dtype(self.dtype)
writer.add_array('projections', spin_k_shape + proj_shape,
proj_dtype)
for spin in range(self.nspins):
for k, rank in enumerate(self.rank_k):
if rank == self.kpt_comm.rank:
wfs = self.wfs_qs[self.q_k[k]][spin]
P_ani = wfs.P_ani.to_cpu().gather() # gather atoms
if P_ani is not None:
P_nI = P_ani.matrix.gather() # gather bands
if P_nI.dist.comm.rank == 0:
if rank == 0:
writer.fill(P_nI.data.reshape(
proj_shape).astype(proj_dtype))
else:
self.kpt_comm.send(P_nI.data, 0)
elif self.comm.rank == 0:
data = np.empty(proj_shape, self.dtype)
self.kpt_comm.receive(data, rank)
writer.fill(data.astype(proj_dtype))
if flags.include_wfs:
self._write_wave_functions(writer, spin_k_shape, flags)
def _write_wave_functions(self, writer, spin_k_shape, flags):
# We collect all bands to master. This may have to be changed
# to only one band at a time XXX
xshape = self.get_max_shape(global_shape=True)
shape = spin_k_shape + (self.nbands,) + xshape
dtype = complex if self.mode == 'pw' else self.dtype
dtype_write = flags.storage_dtype(dtype)
c = 1.0 if self.mode == 'lcao' else Bohr**-1.5
writer.add_array('coefficients', shape, dtype=dtype_write)
buf_nX = np.empty((self.nbands,) + xshape, dtype=dtype)
for spin in range(self.nspins):
for k, rank in enumerate(self.rank_k):
if rank == self.kpt_comm.rank:
wfs = self.wfs_qs[self.q_k[k]][spin]
coef_nX = wfs.gather_wave_function_coefficients()
if coef_nX is not None:
coef_nX = as_np(coef_nX)
if self.mode == 'pw':
x = coef_nX.shape[-1]
if x < xshape[-1]:
# For PW-mode, we may need to zero-pad the
# plane-wave coefficient up to the maximum
# for all k-points:
buf_nX[..., :x] = coef_nX
buf_nX[..., x:] = 0.0
coef_nX = buf_nX
if rank == 0:
writer.fill(flags.to_storage_dtype(coef_nX * c))
else:
self.kpt_comm.send(coef_nX, 0)
elif self.comm.rank == 0:
self.kpt_comm.receive(buf_nX, rank)
writer.fill(flags.to_storage_dtype(buf_nX * c))
[docs]
def write_summary(self, log):
fl = self.fermi_levels * Ha
if len(fl) == 1:
log(f'\nFermi level: {fl[0]:.3f}')
else:
log(f'\nFermi levels: {fl[0]:.3f}, {fl[1]:.3f}')
ibz = self.ibz
eig_skn, occ_skn = self.get_all_eigs_and_occs()
if self.comm.rank != 0:
return
eig_skn *= Ha
D = self.spin_degeneracy
nbands = eig_skn.shape[2]
for k, (x, y, z) in enumerate(ibz.kpt_kc):
if k == 3:
log(f'(only showing first 3 out of {len(ibz)} k-points)')
break
log(f'\nkpt = [{x:.3f}, {y:.3f}, {z:.3f}], '
f'weight = {ibz.weight_k[k]:.3f}:')
if self.nspins == 1:
skipping = False
log(f' Band eig [eV] occ [0-{D}]')
eig_n = eig_skn[0, k]
n0 = (eig_n < fl[0]).sum() - 0.5
for n, (e, f) in enumerate(zips(eig_n, occ_skn[0, k])):
# First, last and +-8 bands window around Fermi level:
if n == 0 or abs(n - n0) < 8 or n == nbands - 1:
log(f' {n:4} {e:13.3f} {D * f:9.3f}')
skipping = False
else:
if not skipping:
log(' ...')
skipping = True
else:
log(' Band eig [eV] occ [0-1]'
' eig [eV] occ [0-1]')
for n, (e1, f1, e2, f2) in enumerate(zips(eig_skn[0, k],
occ_skn[0, k],
eig_skn[1, k],
occ_skn[1, k])):
log(f' {n:4} {e1:13.3f} {f1:9.3f}'
f' {e2:10.3f} {f2:9.3f}')
try:
from ase.dft.bandgap import GapInfo
except ImportError:
log('No gapinfo -- requires new ASE')
return
try:
log()
fermilevel = fl[0]
gapinfo = GapInfo(eigenvalues=eig_skn - fermilevel)
log(gapinfo.description(ibz_kpoints=ibz.kpt_kc))
except ValueError:
# Maybe we only have the occupied bands and no empty bands
log('Could not find a gap')
[docs]
def make_sure_wfs_are_read_from_gpw_file(self):
for wfs in self:
psit_nX = getattr(wfs, 'psit_nX', None)
if psit_nX is None:
return
if hasattr(psit_nX.data, 'fd'): # fd=file-descriptor
self.read_from_file_init_wfs_dm = True
psit_nX.data = psit_nX.data[:] # read
[docs]
def get_homo_lumo(self, spin: int = None) -> Array1D:
"""Return HOMO and LUMO eigenvalues."""
if self.ncomponents == 1:
assert spin != 1
spin = 0
elif self.ncomponents == 2:
if spin is None:
h0, l0 = self.get_homo_lumo(0)
h1, l1 = self.get_homo_lumo(1)
return np.array([max(h0, h1), min(l0, l1)])
else:
assert spin != 1
spin = 0
nocc = 0.0
for wfs_s in self.wfs_qs:
wfs = wfs_s[spin]
nocc += wfs.occ_n.sum() * wfs.weight
nocc = self.kpt_comm.sum_scalar(nocc)
n = int(round(nocc))
homo = -np.inf
if n > 0:
for wfs_s in self.wfs_qs:
homo = max(homo, wfs_s[spin].eig_n[n - 1])
homo = self.kpt_comm.max_scalar(homo)
lumo = np.inf
if n < self.nbands:
for wfs_s in self.wfs_qs:
lumo = min(lumo, wfs_s[spin].eig_n[n])
lumo = self.kpt_comm.min_scalar(lumo)
return np.array([homo, lumo])
[docs]
def calculate_kinetic_energy(self,
hamiltonian,
density: Density) -> float:
e_kin = 0.0
for wfs in self:
e_kin += hamiltonian.calculate_kinetic_energy(wfs, skip_sum=True)
e_kin = self.comm.sum_scalar(e_kin)
# PAW corrections:
e_kin_paw = 0.0
for a, D_sii in density.D_asii.items():
setup = wfs.setups[a]
D_p = pack_density(D_sii.real[:density.ndensities].sum(0))
e_kin_paw += setup.K_p @ D_p + setup.Kc
e_kin_paw = density.grid.comm.sum_scalar(e_kin_paw)
return e_kin + e_kin_paw