# Source code for gpaw.new.pw.hybrids

from __future__ import annotations

from dataclasses import dataclass
from math import pi
from pathlib import Path
from time import time
from typing import IO, Callable, TYPE_CHECKING

import numpy as np
from ase.units import Ha
from gpaw.core import PWArray, PWDesc, UGArray, UGDesc
from gpaw.core.arrays import XArray
from gpaw.core.atom_arrays import AtomArrays
from gpaw.core.pwacf import PWAtomCenteredFunctions
from gpaw.mpi import broadcast
from gpaw.new import zips as zip
from gpaw.new.ibzwfs import IBZWaveFunctions
from gpaw.new.logger import Logger
from gpaw.new.pw.hamiltonian import PWHamiltonian
from gpaw.new.pwfd.ibzwfs import PWFDIBZWaveFunctions
from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
from gpaw.new.xc import create_functional
from gpaw.setup import Setups
from gpaw.utilities import unpack_hermitian, pack_density
from gpaw.utilities.blas import mmm
from scipy.linalg.blas import get_blas_funcs
if TYPE_CHECKING:
    from gpaw.dft import DFT


@dataclass
class Psit:
    """Block of occupied wave functions at one full-BZ k-point and spin.

    Instances are created by ``ibz2bz()`` (one per owned band block and
    spin) and consumed by ``PWHybridHamiltonian`` when applying exchange.
    """
    # Real-space wave functions from ``psit_nG.ifft(..., periodic=False)``:
    ut_nR: UGArray
    # Projections: ``pt_aiG.integrate(psit_nG)``:
    P_ani: AtomArrays
    # Occupation numbers of the bands in this block:
    f_n: np.ndarray
    # k-point of the block (``psit_nG.desc.kpt_c``):
    kpt_c: np.ndarray
    # ``Delta_iiL`` contracted with ``conj(P_ani)`` for every atom:
    Q_aniL: dict[int, np.ndarray]
    # Spin index:
    spin: int
    # Projector derivatives; only set when forces are requested:
    dP_anvi: AtomArrays | None = None  # used for forces


def truncated_coulomb(cell_cv,
                      bz,
                      omega: float = 0.11,
                      yukawa: bool = False
                      ) -> Callable[[PWDesc], np.ndarray]:
    r"""Fourier transform of truncated Coulomb.

    Returns a function that maps a plane-wave descriptor ``pw`` to the
    interaction kernel ``v_G`` on its G-vectors.  Below,
    :math:`\mathbf{q} = \mathbf{G} + \mathbf{k}` and
    :math:`q^2 = 2\,\mathtt{pw.ekin\_G}`.

    For ``yukawa=False`` and ``omega != 0`` the real-space interaction is
    :math:`\operatorname{erfc}(\omega r) / r`, which in reciprocal space is

    .. math::

        v(\mathbf{q}) = \frac{4\pi}{q^2}
                        \left(1 - e^{-q^2 / (4\omega^2)}\right),

    with the :math:`q \to 0` limit :math:`\pi / \omega^2`.

    For ``yukawa=True`` the interaction is :math:`e^{-\omega r} / r`:

    .. math:: v(\mathbf{q}) = \frac{4\pi}{q^2 + \omega^2}.

    For ``omega == 0`` (and ``yukawa=False``) a Wigner-Seitz truncated
    Coulomb kernel built from ``cell_cv`` and ``bz.size_c`` is used.

    Parameters
    ----------
    cell_cv:
        Unit cell (only used when ``omega == 0`` and not Yukawa).
    bz:
        Brillouin zone; ``bz.size_c`` is only used when ``omega == 0``
        and not Yukawa.
    omega:
        Screening parameter ω.
    yukawa:
        Use a Yukawa instead of an erfc-screened interaction.
    """
    if yukawa:
        # 4π / (q² + ω²) written in terms of ekin = q² / 2:
        return lambda pw: 2 * pi / (pw.ekin_G + 0.5 * omega**2)

    if omega != 0.0:
        def f(pw):
            G2_G = pw.ekin_G * 2
            v_G = 4 * pi * (1 - np.exp(-G2_G / (4 * omega**2)))
            ok_G = G2_G > 1e-10
            v_G[ok_G] /= G2_G[ok_G]
            # Analytic q -> 0 limit:
            v_G[~ok_G] = pi / omega**2
            return v_G
        return f

    # Imported only here, so that the analytic kernels above do not
    # depend on the Wigner-Seitz truncation module:
    from gpaw.hybrids.wstc import WignerSeitzTruncatedCoulomb
    wstc = WignerSeitzTruncatedCoulomb(cell_cv, bz.size_c)
    return wstc.get_potential_new
def ibz2bz(ibzwfs: PWFDIBZWaveFunctions,
           setups: Setups,
           relpos_ac: np.ndarray,
           grid: UGDesc,
           plan,  # FFT-plan
           log: Logger | None = None,
           forces: bool = False) -> tuple[list[Psit], int]:
    """Compute BZ from IBZ and distribute.

    The occupied bands (``ibzwfs.number_of_occupied_bands()``) of every
    IBZ k-point are rotated (and possibly complex conjugated) to all
    equivalent BZ k-points.  The ``nbzk * nocc`` bands are then split into
    contiguous chunks of at most ``ceil(nbzk * nocc / comm.size)`` bands,
    one chunk per rank of ``ibzwfs.comm``, sent to their owners and
    transformed to real space.

    Parameters
    ----------
    ibzwfs:
        IBZ wave functions.
    setups:
        PAW setups (projectors ``pt_j`` and ``Delta_iiL`` are used).
    relpos_ac:
        Relative atomic positions.
    grid:
        Real-space grid for the inverse FFT.
    plan:
        FFT plan passed to ``ifft()``.
    log:
        Logger; ``Logger(None, None)`` is used when omitted.
    forces:
        Also compute projector derivatives (``Psit.dP_anvi``).

    Returns
    -------
    The ``Psit`` blocks owned by this rank and the number of occupied
    bands.
    """
    log = log or Logger(None, None)

    nocc = ibzwfs.number_of_occupied_bands()
    nspins = ibzwfs.nspins
    ibz = ibzwfs.ibz
    log(ibz)
    log('Occupied bands:', nocc)
    log('Transforming wave functions from IBZ to BZ: ', end='')
    t1 = time()
    nbzk = len(ibz.bz)
    comm = ibzwfs.comm
    symmetries = ibzwfs.ibz.symmetries

    # rank_Ks and kpt_Kc are filled in locally and then summed over comm,
    # so that all ranks know where each (K, spin) lives:
    rank_Ks = np.zeros((nbzk, nspins), int)
    kpt_Kc = np.zeros((nbzk, 3))
    psit_KsnG = {}
    for wfs1 in ibzwfs:
        # None is returned on ranks that do not receive the collected
        # occupied bands:
        wfs = wfs1.collect_bands_and_domain(0, nocc)
        if wfs is None:
            continue
        for K, k in enumerate(ibz.bz2ibz_K):
            if k != wfs.k:
                continue
            rank_Ks[K, wfs.spin] = comm.rank
            s = ibz.s_K[K]
            U_cc = symmetries.rotation_scc[s]
            complex_conjugate = ibz.time_reversal_K[K]
            assert isinstance(wfs, PWFDWaveFunctions)
            psit1_nG = wfs.psit_nX
            assert isinstance(psit1_nG, PWArray)
            psit2_nG = psit1_nG.transform(U_cc, complex_conjugate)
            if wfs.spin == 0:
                # Written for one spin only, since kpt_Kc is summed:
                kpt_Kc[K] = psit2_nG.desc.kpt_c
            # Rotated k-point must equal the BZ k-point modulo a
            # reciprocal lattice vector:
            dk_c = psit2_nG.desc.kpt_c - ibz.bz.kpt_Kc[K]
            assert abs(dk_c - dk_c.round()).max() < 1e-8
            psit_KsnG[(K, wfs.spin)] = psit2_nG

    comm.sum(rank_Ks)
    comm.sum(kpt_Kc)
    t2 = time()
    log(f'{t2 - t1:.3f} seconds')

    # Chunk the nbzk * nocc bands; entries are (owner rank, K, (na, nb)):
    nocc_total = nocc * nbzk
    blocksize = (nocc_total + comm.size - 1) // comm.size
    blocks = []
    for rank in range(comm.size):
        Ka, na = divmod(rank * blocksize, nocc)
        Kb, nb = divmod((rank + 1) * blocksize, nocc)
        for K in range(Ka, min(Kb, nbzk)):
            blocks.append((rank, K, (na, nocc)))
            na = 0
        if nb > na and Kb < nbzk:
            blocks.append((rank, Kb, (na, nb)))

    log('Distributing wave functions and iFFT-ing to real space: ', end='')
    t1 = time()

    # Non-blocking sends of our chunks to the ranks that own them:
    requests = []
    for (K, spin), psit_nG in psit_KsnG.items():
        for rank, KK, (na, nb) in blocks:
            if KK != K:
                continue
            if rank != comm.rank:
                requests.append(
                    comm.send(psit_nG.data[na:nb], rank, block=False,
                              tag=K * nspins + spin))

    _, occ_skn = ibzwfs.get_all_eigs_and_occs(broadcast=True)

    pw = ibzwfs._wfs_u[0].psit_nX.desc.new(comm=None)
    mypsits = []
    for rank, K, (na, nb) in blocks:
        if rank != comm.rank:
            continue
        pt_aiG = None
        for spin in range(nspins):
            if rank_Ks[K, spin] == rank:
                # Already on this rank:
                psit_nG = psit_KsnG[(K, spin)][na:nb]
            else:
                psit_nG = pw.new(kpt=kpt_Kc[K]).empty(nb - na)
                comm.receive(psit_nG.data, rank_Ks[K, spin],
                             tag=K * nspins + spin)
            if pt_aiG is None:
                # Created once per block and reused for the other spin:
                pt_aiG = psit_nG.desc.atom_centered_functions(
                    [setup.pt_j for setup in setups],
                    relpos_ac)
            P_ani = pt_aiG.integrate(psit_nG)
            psit_nR = psit_nG.ifft(grid=grid, plan=plan, periodic=False)
            Q_aniL = {a: np.einsum('ijL, nj -> niL',
                                   setup.Delta_iiL, P_ani[a].conj())
                      for a, setup in enumerate(setups)}
            k = ibz.bz2ibz_K[K]
            f_n = occ_skn[spin, k, na:nb]
            psit = Psit(psit_nR, P_ani, f_n, psit_nG.desc.kpt_c,
                        Q_aniL, spin)
            if forces:
                psit.dP_anvi = pt_aiG.derivative(psit_nG)
            mypsits.append(psit)

    comm.waitall(requests)
    t2 = time()
    log(f'{t2 - t1:.3f} seconds')

    return mypsits, nocc


class PWHybridHamiltonian(PWHamiltonian):
    """Plane-wave Hamiltonian including an exact-exchange contribution.

    The exchange kernel comes from ``truncated_coulomb()`` and is scaled by
    ``xc.exx_fraction``.  Occupied wave functions for the full BZ are kept
    in ``self.mypsits`` (see ``update_wave_functions()``).
    """
    # NOTE(review): presumably means the operator needs all bands, not
    # only the band-local ones - confirm against PWHamiltonian.
    band_local = False

    def __init__(self,
                 grid: UGDesc,
                 pw: PWDesc,
                 xc,
                 setups: Setups,
                 relpos_ac,
                 atomdist,
                 log,
                 bz,
                 kpt_comm,
                 band_comm,
                 comm):
        super().__init__(grid, pw.dtype)
        self.pw = pw
        self.exx_fraction = xc.exx_fraction
        self.xc = xc
        self.kpt_comm = kpt_comm
        self.band_comm = band_comm
        self.comm = comm
        self.log = log
        self.relpos_ac = relpos_ac
        self.setups = setups
        self.nbzk = len(bz)
        self.real = np.issubdtype(pw.dtype, np.floating)
        self.zaxpy = get_blas_funcs('axpy', dtype=complex)

        # Stuff for PAW core-core, core-valence and valence-valence
        # corrections:
        self.exx_cc = sum(setup.ExxC for setup in setups) * self.exx_fraction
        self.VC_aii = [unpack_hermitian(setup.X_p * self.exx_fraction)
                       for setup in setups]
        self.delta_aiiL = [setup.Delta_iiL for setup in setups]
        self.VV_app = [setup.M_pp * self.exx_fraction for setup in setups]

        # Globally distributed wave functions:
        self.mypsits: list[Psit] = []

        # Coulomb kernel: function mapping a PWDesc to v_G:
        self.coulomb = truncated_coulomb(
            grid.cell_cv, bz, xc.exx_omega, xc.exx_yukawa)

        self.nupdates = 0
        # Energy accumulators (NaN until update_wave_functions() is called):
        self.devc = np.nan
        self.devv = np.nan
        self.evv = np.nan
        self.dekin = np.nan

    def update_wave_functions(self,
                              ibzwfs: PWFDIBZWaveFunctions,
                              forces=False) -> None:
        """Compute BZ from IBZ and distribute over the entire world!

        Resets the energy accumulators to zero.  Only the first update is
        logged.
        """
        self.mypsits, _ = ibz2bz(
            ibzwfs, self.setups, self.relpos_ac, self.grid_local, self.plan,
            self.log if self.nupdates == 0 else None,
            forces)
        self.devc = 0.0
        self.devv = 0.0
        self.evv = 0.0
        self.dekin = 0.0
        self.nupdates += 1

    def hybrid_energy_contributions(self) -> tuple[float, float,
                                                   float, float]:
        """Return (core-core, valence-core, valence-valence, dekin) energies.

        ``devc`` and ``devv`` are summed over ``comm`` here.  ``evv`` and
        ``dekin`` are not summed again, because the values returned by
        ``_apply2()`` are already summed over ``comm``.  All accumulators
        are reset to NaN afterwards.
        """
        devc = self.comm.sum_scalar(self.devc)
        devv = self.comm.sum_scalar(self.devv)
        dekin = -devc - 2 * devv
        energies = (self.exx_cc,
                    devc,
                    devv + self.evv,
                    dekin + self.dekin)
        self.devc = np.nan
        self.devv = np.nan
        self.evv = np.nan
        self.dekin = np.nan
        return energies

    def move(self, relpos_av: np.ndarray) -> None:
        """Set new relative atomic positions."""
        self.relpos_ac = relpos_av

    def apply_orbital_dependent(self,
                                ibzwfs: IBZWaveFunctions,
                                D_asii,
                                psit2_nG: XArray,
                                spin: int,
                                Htpsit2_nG: XArray | None = None,
                                calculate_energy: bool = False,
                                F_av: np.ndarray | None = None) -> None:
        """Apply the exact-exchange operator to ``psit2_nG``.

        Parameters
        ----------
        ibzwfs:
            IBZ wave functions; used to find projectors, k-point weight and
            occupations belonging to ``psit2_nG``.
        D_asii:
            Atomic density matrices.
        psit2_nG:
            Wave functions to act on.
        spin:
            Spin index of ``psit2_nG``.
        Htpsit2_nG:
            If given, the result is added to this array.
        calculate_energy:
            Accumulate ``self.devc``, ``self.devv``, ``self.evv`` and
            ``self.dekin``.
        F_av:
            If given, force contributions are added to it.
        """
        from gpaw.hybrids.paw import pawexxvv

        assert isinstance(psit2_nG, PWArray)
        assert Htpsit2_nG is None or isinstance(Htpsit2_nG, PWArray)
        assert isinstance(ibzwfs, PWFDIBZWaveFunctions)

        F1_av = None if F_av is None else np.zeros_like(F_av)

        # Find projectors and k-point weight for psit2_nG:
        for u, wfs in enumerate(ibzwfs):
            if wfs.spin != spin:
                continue
            if np.allclose(wfs.psit_nX.desc.kpt_c, psit2_nG.desc.kpt_c):
                pt_aiG = wfs.pt_aiX
                assert isinstance(pt_aiG, PWAtomCenteredFunctions)
                kweight = wfs.weight
                break
        else:  # no break
            # Explicit raise instead of "assert False" so that this still
            # fails here (and not later with a NameError) under python -O:
            raise AssertionError(
                f'k-point not found: {psit2_nG.desc.kpt_c}')

        D_aii = D_asii[:, spin].copy()
        if ibzwfs.nspins == 1:
            # D_aii is already a private copy, so scale it in place:
            D_aii.data *= 0.5

        # PAW-corrections:
        V_aii = D_aii.new()
        for a, D_ii in D_aii.items():
            VV_ii = pawexxvv(self.VV_app[a], D_ii)
            VC_ii = self.VC_aii[a]
            V_ii = -VC_ii - 2 * VV_ii
            V_aii[a] = V_ii
            if not calculate_energy:
                continue
            if wfs.k > 0 or self.band_comm.rank > 0:
                # Doesn't depend on k (count it only once):
                continue
            if wfs.weight == 0.0:
                # zero-padding
                continue
            ec = (D_ii * VC_ii).sum()
            ev = (D_ii * VV_ii).sum()
            self.devv -= ev * ibzwfs.spin_degeneracy
            self.devc -= ec * ibzwfs.spin_degeneracy

        # distribute V_aii (presumably collective - called on all ranks):
        V2_aii = V_aii.gather(broadcast=True)

        if F1_av is not None and u == 0 and wfs.weight != 0.0:
            # NOTE(review): V_ii of `spin` is applied to every psit in
            # self.mypsits without checking psit.spin, and `u == 0` selects
            # only the first local IBZ wave function.  Confirm that this is
            # correct for spin-polarized calculations.
            for a, V_ii in V2_aii.items():
                for psit in self.mypsits:
                    dP_anvi = psit.dP_anvi
                    assert dP_anvi is not None
                    force_v = np.einsum('ni, nvi, n -> v',
                                        psit.P_ani[a] @ V_ii,
                                        dP_anvi[a].conj(),
                                        psit.f_n).real
                    force_v = 2 / self.nbzk * force_v
                    F1_av[a] += force_v

        evv = self._apply1(spin, D_aii, pt_aiG,
                           psit2_nG, Htpsit2_nG,
                           kweight, wfs.myocc_n,
                           V_aii,
                           calculate_energy,
                           F1_av)
        evv *= 0.5 * ibzwfs.spin_degeneracy
        if calculate_energy:
            self.evv += evv
            self.dekin -= 2 * evv

        if F1_av is not None:
            assert F_av is not None
            F_av += ibzwfs.spin_degeneracy * F1_av

    def _apply1(self,
                spin: int,
                D_aii,
                pt_aiG: PWAtomCenteredFunctions,
                psit_nG: PWArray,
                Htpsit_nG: PWArray | None,
                kweight: float,
                f_n: np.ndarray,
                V_aii,
                calculate_energy: bool,
                F1_av=None) -> float:
        """Loop over all (k-rank, band-rank) owners of wave functions.

        Each owner's block is broadcast to all of ``comm``.  Every rank
        adds the contribution from its own ``self.mypsits``.  The resulting
        potential is summed back onto the owner and added to
        ``Htpsit_nG``.  Returns the weighted energy.

        NOTE(review): the owner rank is computed as
        ``(brank + krank * band_comm.size) * domain_comm.size``, which
        assumes this rank ordering of ``comm``.
        """
        comm = self.comm
        band_comm = self.band_comm
        domain_comm = psit_nG.desc.comm

        P_ani = pt_aiG.integrate(psit_nG)
        # Local PAW potential acting on the projections:
        V0_ani = P_ani.new()
        for a, _ in D_aii.items():
            V0_ani[a] = P_ani[a] @ V_aii[a]

        e = 0.0
        for krank in range(self.kpt_comm.size):
            for brank in range(band_comm.size):
                data = None
                if krank == self.kpt_comm.rank and brank == band_comm.rank:
                    psit2_nG = psit_nG.gather()
                    P2_ani = P_ani.gather()
                    if psit2_nG is not None:
                        # Remove band_comm so that data can be pickled
                        # when calling broadcast(data, ...) later:
                        psit2_nG = psit2_nG[:]
                        P2_ani = AtomArrays(P2_ani.layout,
                                            dims=(len(P2_ani.data),),
                                            data=P2_ani.data)
                        data = (psit2_nG, P2_ani, f_n, spin, kweight)
                rank = (brank + krank * band_comm.size) * domain_comm.size
                psit2_nG, P2_ani, f2_n, s, w = broadcast(data, rank,
                                                         comm=comm)
                V_nG = psit2_nG.new()
                V_nG.data[:] = 0.0
                V_ani = P2_ani.new()
                V_ani.data[:] = 0.0
                e += self._apply2(psit2_nG, P2_ani, s, V_nG, V_ani, f2_n,
                                  calculate_energy, w, F1_av) * w
                if Htpsit_nG is None:
                    continue
                comm.sum(V_nG.data, root=rank)
                comm.sum(V_ani.data, root=rank)
                if krank == self.kpt_comm.rank:
                    if brank == band_comm.rank:
                        V2_nG = Htpsit_nG.new()
                        V2_nG.scatter_from(V_nG)
                        V2_ani = V0_ani.new()
                        V2_ani.scatter_from(V_ani)
                        V2_ani.data += V0_ani.data
                        Htpsit_nG.data += V2_nG.data
                        pt_aiG.add_to(Htpsit_nG, V2_ani)
        return e

    def _apply2(self,
                psit2_nG: PWArray,
                P2_ani: AtomArrays,
                spin: int,
                Htpsit2_nG,
                V2_ani,
                f2_n: np.ndarray,
                calculate_energy: bool,
                w: float,
                F1_av=None) -> float:
        """Pair the broadcast block with all local ``mypsits`` of ``spin``.

        Returns the energy summed over ``comm``.
        """
        ut2_nR = self.grid_local.empty(len(psit2_nG))
        psit2_nG.ifft(out=ut2_nR, plan=self.plan, periodic=False)
        e = 0.0
        pw2 = psit2_nG.desc
        for psit1 in self.mypsits:
            if psit1.spin == spin:
                # Pair densities live at the k-point difference:
                pw = pw2.new(kpt=pw2.kpt_c - psit1.kpt_c)
                v_G = self.coulomb(pw)
                e += self._apply3(
                    pw, v_G, psit1, ut2_nR, P2_ani, Htpsit2_nG, V2_ani,
                    f2_n, calculate_energy, F1_av, w)
        e *= -self.exx_fraction / self.nbzk
        return self.comm.sum_scalar(e)

    def _apply3(self,
                pw: PWDesc,
                v_G: np.ndarray,
                psit1: Psit,
                ut2_nR: UGArray,
                P2_ani: AtomArrays,
                Htpsit2_nG: PWArray,
                V2_ani,
                f2_n: np.ndarray,
                calculate_energy: bool,
                F1_av: np.ndarray | None,
                w: float) -> float:
        """Exchange between one local ``Psit`` and the broadcast block.

        For each band n1 of ``psit1``, pair densities with all bands n2 are
        built, including compensation charges, and multiplied by ``v_G``.
        Energy is accumulated if requested.  After that, either forces
        are added (when ``F1_av`` is given) or the potential is applied to
        ``Htpsit2_nG`` and ``V2_ani``.
        """
        ut1_nR = psit1.ut_nR
        Q1_aniL = psit1.Q_aniL
        f1_n = psit1.f_n

        ghat_aLG = self.setups.create_compensation_charges(
            pw, self.relpos_ac)
        ghat_aLG._lazy_init()
        ghat_GA = ghat_aLG._lfc.expand(cc=not self.real)

        N2 = len(ut2_nR)
        Q_anL = ghat_aLG.layout.empty(N2)
        rhot2_nG = pw.empty(N2)
        tmp_Q = self.plan.tmp_Q
        tmp_R = self.plan.tmp_R
        eikR_a = ghat_aLG._lfc.eikR_a
        pw2 = Htpsit2_nG.desc
        NR = tmp_R.size
        NG = pw.myshape[0]
        NG2 = pw2.myshape[0]
        tmp_G = np.empty(NG, complex)
        Q_G = pw.indices(tmp_Q.shape)
        Q2_G = pw2.indices(tmp_Q.shape)

        e = 0.0
        for n1, ut1_R in enumerate(ut1_nR.data):
            f1 = f1_n[n1]
            # Compensation-charge coefficients for pairs (n1, all n2):
            for a, Q1_niL in Q1_aniL.items():
                Q_anL[a] = P2_ani[a] @ Q1_niL[n1] * eikR_a[a].conj()
            if self.real:
                mmm(1.0 / pw.dv, Q_anL.data, 'N', ghat_GA, 'T',
                    0.0, rhot2_nG.data.view(float))
            else:
                mmm(1.0 / pw.dv, Q_anL.data, 'N', ghat_GA, 'C',
                    0.0, rhot2_nG.data)

            for n2, (rhot_G, ut2_R) in enumerate(zip(rhot2_nG.data,
                                                     ut2_nR.data)):
                # Add the smooth pair density conj(ut1) * ut2 in G-space:
                tmp_R[:] = ut2_R
                tmp_R *= ut1_R.conj()
                self.plan.fft()
                a_G = tmp_Q.ravel()[Q_G]
                self.zaxpy(a_G, rhot_G, NG, 1.0 / NR)
                if not calculate_energy:
                    rhot_G *= v_G
                else:
                    tmp_G[:] = rhot_G
                    rhot_G *= v_G
                    e12 = tmp_G.view(float) @ rhot_G.view(float)
                    if self.real:
                        # Presumably only half of the G-vectors are stored
                        # for real wave functions: double all but G=0.
                        e12 = 2 * e12 - (tmp_G[0] * rhot_G[0]).real
                    e += e12 * f2_n[n2] * f1 * pw.dv

            if F1_av is not None:
                forces(ghat_aLG, rhot2_nG, P2_ani, Q_anL, f1, f2_n,
                       self.nbzk, self.delta_aiiL, psit1.dP_anvi, n1,
                       eikR_a, F1_av, w)
                continue

            if self.real:
                ghat_GA[0] *= 0.5
                mmm(2.0, rhot2_nG.data.view(float), 'N', ghat_GA, 'N',
                    0.0, Q_anL.data)
                ghat_GA[0] *= 2.0
            else:
                mmm(1.0, rhot2_nG.data, 'N', ghat_GA, 'N',
                    0.0, Q_anL.data)

            x = self.exx_fraction * f1 / self.nbzk
            for rhot_G, Htpsit2_G in zip(rhot2_nG.data, Htpsit2_nG.data):
                self.plan.ifft_sphere(rhot_G, pw)
                tmp_R *= ut1_R
                self.plan.fft()
                # Equivalent to: Htpsit2_G -= x / NR * pw2.cut(tmp_Q)
                v2_G = tmp_Q.ravel()[Q2_G]
                self.zaxpy(v2_G, Htpsit2_G, NG2, -x / NR)
            for a, Q1_niL in Q1_aniL.items():
                V2_ani[a] -= x * Q_anL[a] @ Q1_niL[n1].T.conj() * eikR_a[a]
        return e


def forces(ghat_aLG, vrhot2_nG, P2_ani, Q2_anL, f1, f2_n, nbzk,
           delta_aiiL, dP_anvi, n1, eikR_a, F_av, w):
    """Add exchange force contributions for band ``n1`` to ``F_av``.

    Two terms are subtracted.  One comes from the derivative of the
    compensation charges ``ghat_aLG`` acting on ``vrhot2_nG``.  The other
    comes from the projector derivatives ``dP_anvi[a][n1]`` contracted with
    ``delta_aiiL`` and the integrals of ``ghat_aLG`` with ``vrhot2_nG``.
    Both are weighted by ``w / nbzk`` and the occupations ``f1 * f2_n``.
    """
    f12_n = f1 * f2_n
    w *= 1 / nbzk
    for a, F_nvL in ghat_aLG.derivative(vrhot2_nG).items():
        F_av[a] -= 0.25 * w * np.einsum('n, nL, nvL -> v',
                                        f12_n,
                                        (Q2_anL[a] * eikR_a[a]).conj(),
                                        F_nvL).real
    for a, F_nL in ghat_aLG.integrate(vrhot2_nG).items():
        F_iin = delta_aiiL[a] @ F_nL.T
        F_av[a] -= 0.5 * w * np.einsum('ijn, vi, nj, n -> v',
                                       F_iin,
                                       dP_anvi[a][n1],
                                       P2_ani[a].conj(),
                                       f12_n).real


def non_self_consistent_hybrid_xc_energy(
        dft: DFT,
        xc: str,
        *,
        log: str | Path | IO[str] | Logger | None = '-') -> np.ndarray:
    """Evaluate a hybrid functional non-self-consistently on a DFT result.

    The returned energy contributions are (in eV):

    1. DFT total energy (``dft.energies.total_extrapolated``)
    2. minus DFT XC energy
    3. Hybrid semi-local XC energy
    4. EXX core-core energy
    5. EXX valence-core energy
    6. EXX valence-valence energy

    NOTE(review): item 1 is the *extrapolated* total energy.  Confirm
    whether the free (non-extrapolated) energy was intended instead.
    """
    if not isinstance(log, Logger):
        log = Logger(log, comm=dft.comm)

    ibzwfs = dft.ibzwfs
    exx = create_functional({'name': xc, 'backend': 'pw'},
                            dft.pot_calc.fine_grid)
    hybham = PWHybridHamiltonian(
        dft.density.grid,
        next(iter(ibzwfs)).psit_nX.desc,
        exx,
        dft.setups,
        dft.relpos_ac,
        dft.density.D_asii.layout.atomdist,
        log,
        ibzwfs.ibz.bz,
        ibzwfs.kpt_comm,
        ibzwfs.band_comm,
        dft.comm)
    ibzwfs.make_sure_wfs_are_read_from_gpw_file()
    assert isinstance(ibzwfs, PWFDIBZWaveFunctions)
    hybham.update_wave_functions(ibzwfs)

    edft = dft.energies.total_extrapolated
    exc = dft.energies._energies['xc']
    log(f'DFT energy: {edft * Ha} eV')
    log(f'minus DFT-XC energy: {-exc * Ha} eV')

    log('Semi-local contribution:', end=' ', flush=True)
    t1 = time()
    semilocal_energy = _semilocal_xc_energy(dft, xc)
    t2 = time()
    log(f'{semilocal_energy * Ha:.3f} eV ({t2 - t1:.3f} seconds)')

    log('Calculating EXX contributions:', end=' ', flush=True)
    for wfs in ibzwfs.zero_padded_iter():
        hybham.apply_orbital_dependent(
            ibzwfs, dft.density.D_asii, wfs.psit_nX,
            spin=wfs.spin, calculate_energy=True)
    t3 = time()
    log(f'{t3 - t2:.3f} seconds')

    ecc, evc, evv, _ = hybham.hybrid_energy_contributions()
    log(f'Core-core contribution: {ecc * Ha:12.3f}')
    log(f'Valence-core contribution: {evc * Ha:12.3f}')
    log(f'Valence-valence contribution: {evv * Ha:12.3f}', flush=True)

    return np.array(
        [edft, -exc, semilocal_energy, ecc, evc, evv]) * Ha


def _semilocal_xc_energy(dft: DFT, xc: str) -> float:
    """Semi-local part of hybrid ``xc`` evaluated for the DFT density.

    Sums the PAW corrections and the smooth-density contribution on the
    fine grid.
    """
    from gpaw.hybrids import parse_name
    semilocal_xc_name, _, _, _ = parse_name(xc)
    fine_grid = dft.pot_calc.fine_grid
    slxc = create_functional(semilocal_xc_name, fine_grid)
    nt_sr = dft.density.nt_sR.interpolate(grid=fine_grid)
    energy = 0.0
    for a, D_sii in dft.density.D_asii.items():
        D_sp = np.array([pack_density(D_ii.real) for D_ii in D_sii])
        energy += slxc.calculate_paw_correction(dft.setups[a], D_sp)
    # Presumably D_asii.items() yields only this rank's atoms, hence the
    # sum over the grid communicator:
    energy = dft.density.nt_sR.desc.comm.sum_scalar(energy)
    energy += slxc.calculate(nt_sr)[0]
    return energy