from math import pi, sqrt
from warnings import warn
import numpy as np
from ase.units import Bohr, Ha
from gpaw.atom.shapefunc import shape_functions
from gpaw.fftw import get_efficient_fft_size
from gpaw.lfc import LocalizedFunctionsCollection as LFC
from gpaw.mpi import serial_comm
from gpaw.old.calculator import GPAW
from gpaw.old.grid_descriptor import GridDescriptor
from gpaw.old.pw.descriptor import PWDescriptor
from gpaw.setup import Setup
from gpaw.spline import Spline
from gpaw.typing import Array3D
from gpaw.utilities import h2gpts
class Interpolator:
    """Interpolate arrays from one grid to another.

    Wraps two PWDescriptor objects, one per grid, and delegates the
    actual work to ``PWDescriptor.interpolate``.
    """

    def __init__(self, gd1, gd2, dtype=float):
        """Create the descriptors for the source and target grids.

        gd1:
            Grid descriptor of the grid the input arrays live on.
        gd2:
            Grid descriptor of the grid to interpolate to.
        dtype:
            Data type handed to both PWDescriptor objects.
        """
        # 0.0 is passed as the first PWDescriptor argument for both grids.
        self.pd1 = PWDescriptor(0.0, gd1, dtype)
        self.pd2 = PWDescriptor(0.0, gd2, dtype)

    def interpolate(self, a_r):
        """Return ``a_r`` interpolated from the first grid to the second.

        Only the first element of what ``self.pd1.interpolate`` returns is
        handed back to the caller.
        """
        return self.pd1.interpolate(a_r, self.pd2)[0]
# Value passed as ``points=`` when radial functions are turned into splines.
POINTS = 200
class PS2AE:
    """Transform PS to AE wave functions.

    Interpolates PS wave functions to a fine grid and adds PAW
    corrections in order to obtain true AE wave functions.
    """

    def __init__(self,
                 calc: GPAW,
                 grid_spacing: float = 0.05,
                 n: int = 2,
                 h=None  # deprecated
                 ):
        """Create transformation object.

        calc: GPAW calculator object
            The calculator that has the wave functions.
        grid_spacing: float
            Desired grid-spacing in Angstrom.
        n: int
            Force number of points to be a multiple of n.
        h: float
            Deprecated alias for grid_spacing.  When given, a warning is
            issued and its value replaces grid_spacing.
        """
        if h is not None:
            # stacklevel=2 makes the warning point at the calling code.
            warn('Please use grid_spacing=... instead of h=...',
                 stacklevel=2)
            grid_spacing = h

        self.calc = calc
        gd = calc.wfs.gd

        # Serial descriptor with the size and cell of the calculator's grid:
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)

        # Descriptor for the final grid:
        N_c = h2gpts(grid_spacing / Bohr, gd.cell_cv)
        N_c = np.array([get_efficient_fft_size(N, n) for N in N_c])
        gd2 = self.gd = GridDescriptor(N_c, gd.cell_cv, comm=serial_comm)

        self.interpolator = Interpolator(gd1, gd2, self.calc.wfs.dtype)

        self._dphi: LFC | None = None  # PAW correction, built on demand

        # self.gd.dv scaled by Bohr**3 (presumably the volume per grid
        # point of the final grid in Angstrom**3 -- TODO confirm).
        self.dv = self.gd.dv * Bohr**3

    @property
    def dphi(self) -> LFC:
        """PAW correction functions on the final grid.

        The LFC object is built on first access and cached in
        ``self._dphi``.  For every setup the radial differences
        ``phi_g - phit_g`` taken from ``setup.data`` are cut off at
        1.1 times the largest ``setup.rcut_j`` and converted to splines.
        """
        if self._dphi is not None:
            return self._dphi

        # Atoms that share a setup object share one list of splines.
        splines: dict[Setup, list[Spline]] = {}
        dphi_aj = []
        for setup in self.calc.wfs.setups:
            dphi_j = splines.get(setup)
            if dphi_j is None:
                rcut = max(setup.rcut_j) * 1.1
                gcut = setup.rgd.ceil(rcut)
                dphi_j = []
                for l, phi_g, phit_g in zip(setup.l_j,
                                            setup.data.phi_jg,
                                            setup.data.phit_jg):
                    dphi_g = (phi_g - phit_g)[:gcut]
                    dphi_j.append(setup.rgd.spline(dphi_g, rcut, l,
                                                   points=POINTS))
                splines[setup] = dphi_j
            dphi_aj.append(dphi_j)

        self._dphi = LFC(self.gd, dphi_aj, kd=self.calc.wfs.kd.copy(),
                         dtype=self.calc.wfs.dtype)
        self._dphi.set_positions(self.calc.spos_ac)

        return self._dphi

    def get_wave_function(self,
                          n: int,
                          k: int = 0,
                          s: int = 0,
                          ae: bool = True,
                          periodic: bool = False) -> Array3D:
        """Interpolate wave function.

        Returns 3-d array in units of Ang**-1.5.

        n: int
            Band index.
        k: int
            K-point index.
        s: int
            Spin index.
        ae: bool
            Add PAW correction to get an all-electron wave function.
        periodic:
            Return periodic part of wave-function, u(r), instead of
            psi(r)=exp(ikr)u(r).
        """
        # Always fetch the periodic part; the phase factor is applied below.
        u_r = self.calc.get_pseudo_wave_function(n, k, s,
                                                 periodic=True)
        u_R = self.interpolator.interpolate(u_r * Bohr**1.5)

        k_c = self.calc.wfs.kd.ibzk_kc[k]
        gamma = np.isclose(k_c, 0.0).all()
        if gamma:
            # No phase factor is needed when all components of k_c are zero.
            eikr_R = 1.0
        else:
            eikr_R = self.gd.plane_wave(k_c)

        if ae:
            dphi = self.dphi
            wfs = self.calc.wfs
            P_nI = wfs.collect_projections(k, s)

            if wfs.world.rank == 0:
                psi_R = u_R * eikr_R

                # Slice the collected projections of band n into one
                # block of length setup.ni per atom.
                P_ai = {}
                I1 = 0
                for a, setup in enumerate(wfs.setups):
                    I2 = I1 + setup.ni
                    P_ai[a] = P_nI[n, I1:I2]
                    I1 = I2

                dphi.add(psi_R, P_ai, k)
                u_R = psi_R / eikr_R

            # The correction is only added on rank 0; share the result.
            wfs.world.broadcast(u_R, 0)

        if periodic:
            return u_R * Bohr**-1.5
        else:
            return u_R * eikr_R * Bohr**-1.5

    def get_pseudo_density(self,
                           add_compensation_charges: bool = True) -> Array3D:
        """Interpolate pseudo density.

        add_compensation_charges: bool
            Also add the ``setup.ghat_l`` functions, weighted by a modified
            copy of ``dens.Q_aL``, to the interpolated density.

        The communicator of the calculator's density grid must have
        size 1 (asserted).  The result is divided by Bohr**3.
        """
        dens = self.calc.density
        gd1 = dens.gd
        assert gd1.comm.size == 1
        interpolator = Interpolator(gd1, self.gd)

        # Sum over the first dens.nspins components of nt_sG.
        dens_r = dens.nt_sG[:dens.nspins].sum(axis=0)
        dens_R = interpolator.interpolate(dens_r)

        if add_compensation_charges:
            dens.calculate_multipole_moments()
            ghat = LFC(self.gd, [setup.ghat_l for setup in dens.setups],
                       integral=sqrt(4 * pi))
            ghat.set_positions(self.calc.spos_ac)

            # Work on copies so that dens.Q_aL itself is left untouched.
            Q_aL = {}
            for a, Q_L in dens.Q_aL.items():
                Q_aL[a] = Q_L.copy()
                # NOTE(review): presumably adds the monopole of the setup's
                # Nv valence electrons -- confirm.
                Q_aL[a][0] += dens.setups[a].Nv / (4 * pi)**0.5

            ghat.add(dens_R, Q_aL)

        return dens_R / Bohr**3

    def get_electrostatic_potential(self,
                                    ae: bool = True,
                                    rcgauss: float = 0.02) -> Array3D:
        """Interpolate electrostatic potential.

        Return value in eV.

        ae: bool
            Add PAW correction to get the all-electron potential.
        rcgauss: float
            Width of gaussian (in Angstrom) used to represent the nuclear
            charge.
        """
        gd = self.calc.hamiltonian.finegd
        # Divide by Ha here; the result is multiplied by Ha again on return.
        v_r = self.calc.get_electrostatic_potential() / Ha
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
        interpolator = Interpolator(gd1, self.gd)
        v_R = interpolator.interpolate(v_r)

        if ae:
            self.add_potential_correction(v_R, rcgauss / Bohr)

        return v_R * Ha

    def add_potential_correction(self,
                                 v_R: Array3D,
                                 rcgauss: float) -> None:
        """Add atom-centered corrections to the potential ``v_R`` in place.

        v_R:
            Interpolated potential on the final grid; modified in place.
        rcgauss: float
            Width of the gaussian representing the nuclear charge.
            get_electrostatic_potential passes its Angstrom value divided
            by Bohr.

        ``dens.D_asp`` and ``dens.Q_aL`` are redistributed to the serial
        atom partition while the corrections are built and redistributed
        back to ``dens.atom_partition`` afterwards.  Finally ``v_R`` is
        broadcast from rank 0 of ``dens.gd.comm``.
        """
        dens = self.calc.density
        dens.D_asp.redistribute(dens.atom_partition.as_serial())
        dens.Q_aL.redistribute(dens.atom_partition.as_serial())

        dv_a1 = []
        for a, D_sp in dens.D_asp.items():
            setup = dens.setups[a]
            c = setup.xc_correction
            rgd = c.rgd

            # lmax=0 shape function of the setup and a gaussian of width
            # rcgauss scaled by setup.Z:
            params = setup.data.shape_function.copy()
            params['lmax'] = 0
            ghat_g = shape_functions(rgd, **params)[0]
            Z_g = shape_functions(rgd, 'gauss', rcgauss, lmax=0)[0] * setup.Z

            # Radial density difference assembled from the n_qg - nt_qg and
            # nc_g - nct_g terms, minus the gaussian Z_g and the weighted
            # shape function.
            D_q = np.dot(D_sp.sum(0), c.B_pqL[:, :, 0])
            dn_g = np.dot(D_q, (c.n_qg - c.nt_qg)) * sqrt(4 * pi)
            dn_g += 4 * pi * (c.nc_g - c.nct_g)
            dn_g -= Z_g
            dn_g -= dens.Q_aL[a][0] * ghat_g * sqrt(4 * pi)

            dv_g = rgd.poisson(dn_g) / sqrt(4 * pi)
            # Divide by r everywhere except at the first grid point, which
            # takes the value of its neighbour instead.
            dv_g[1:] /= rgd.r_g[1:]
            dv_g[0] = dv_g[1]
            # Force the correction to zero at the last grid point.
            dv_g[-1] = 0.0
            dv_a1.append([rgd.spline(dv_g, points=POINTS)])

        dens.D_asp.redistribute(dens.atom_partition)
        dens.Q_aL.redistribute(dens.atom_partition)

        if dv_a1:
            dv = LFC(self.gd, dv_a1)
            dv.set_positions(self.calc.spos_ac)
            dv.add(v_R)
        dens.gd.comm.broadcast(v_R, 0)
def interpolate_weight(calc, weight, h=0.05, n=2):
    """Interpolate a cDFT weight function to a finer serial grid.

    calc:
        Calculator object; ``calc.density.finegd`` is the grid descriptor
        used to collect and zero-pad ``weight``.
    weight:
        Weight function on ``calc.density.finegd``.
    h: float
        Grid spacing of the target grid.  It is divided by Bohr before
        being passed to h2gpts, so it is presumably given in Angstrom
        -- TODO confirm.
    n: int
        Second argument to get_efficient_fft_size for every grid
        dimension.

    Returns the weight function interpolated to the target grid.
    """
    gd = calc.density.finegd

    # Gather the full array on every rank, then zero-pad it.
    weight = gd.collect(weight, broadcast=True)
    weight = gd.zero_pad(weight)

    # Serial copy of the source grid that receives the collected data.
    w = np.zeros_like(weight)
    gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
    gd1.distribute(weight, w)

    # Descriptor for the target grid:
    N_c = h2gpts(h / Bohr, gd.cell_cv)
    N_c = np.array([get_efficient_fft_size(N, n) for N in N_c])
    gd2 = GridDescriptor(N_c, gd.cell_cv, comm=serial_comm)

    interpolator = Interpolator(gd1, gd2)
    return interpolator.interpolate(w)