from __future__ import annotations

from collections.abc import Callable, Iterator
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

import numpy as np
from ase.io.ulm import NDArrayReader

import gpaw.fftw as fftw
from gpaw.core.domain import Domain
from gpaw.core.matrix import Matrix
from gpaw.gpu import XP
from gpaw.mpi import MPIComm
from gpaw.new import prod, trace
from gpaw.typing import Array1D, ArrayND, Self

if TYPE_CHECKING:
    from gpaw.core.uniform_grid import UGArray, UGDesc
# Type variable for the domain descriptor (``desc``) carried by an
# XArray; restricted to subclasses of Domain.
DomainType = TypeVar(
    'DomainType',
    bound=Domain,
)
class XArrayWithNoData:
    """Array-like stand-in that carries metadata but no data.

    Only ``comm``, ``dims``, ``desc`` and ``xp`` are stored; ``data`` is
    always ``None``.  Calling :meth:`morph` raises
    ``ReuseWaveFunctionsError``.
    """

    def __init__(self,
                 comm,
                 dims,
                 desc,
                 xp):
        self.comm, self.dims, self.desc, self.xp = comm, dims, desc, xp
        self.data = None  # no data is ever held

    def morph(self, desc):
        """Always raise: there is no data that could be moved to *desc*."""
        from gpaw.new.calculation import ReuseWaveFunctionsError
        raise ReuseWaveFunctionsError
class XArray(Generic[DomainType], XP):
    """Base class for distributed arrays.

    The array has leading dimensions ``dims``.  The first of these is
    split into contiguous blocks over ``comm``, and a domain part of local
    shape ``myshape`` follows.  The local data therefore has shape
    ``mydims + myshape``.  Domain-specific operations (indexing, gathering,
    integration, ...) are implemented by subclasses.
    """

    desc: DomainType

    def __init__(self,
                 dims: int | tuple[int, ...],
                 myshape: tuple[int, ...],
                 comm: MPIComm,
                 domain_comm: MPIComm,
                 data: np.ndarray | None,
                 dv: float,
                 dtype,
                 xp=None):
        """Set up distribution information and (optionally) allocate data.

        Parameters
        ----------
        dims:
            Global leading dimensions (an int is turned into a 1-tuple).
        myshape:
            Local shape of the domain part of the data.
        comm:
            Communicator over which the first axis of ``dims`` is split.
        domain_comm:
            Communicator used for the domain sum in
            :meth:`matrix_elements`.
        data:
            Existing array of shape ``mydims + myshape`` and dtype
            ``dtype``, or ``None`` to allocate an uninitialized one.
        dv:
            Scale factor applied to inner products in
            :meth:`matrix_elements` (presumably a volume element).
        dtype:
            Data type of the array.
        xp:
            Array module (numpy or cupy) used when allocating *data*.

        Raises
        ------
        ValueError
            If *data* has the wrong shape or dtype.
        """
        self.myshape = myshape
        self.comm = comm
        self.domain_comm = domain_comm
        self.dv = dv

        # convert int to tuple:
        self.dims = dims if isinstance(dims, tuple) else (dims,)

        if self.dims:
            # Block distribution of the first axis: each rank gets up to
            # ceil(dims[0] / comm.size) entries; trailing ranks may get
            # fewer (possibly zero).
            mydims0 = (self.dims[0] + comm.size - 1) // comm.size
            d1 = min(comm.rank * mydims0, self.dims[0])
            d2 = min((comm.rank + 1) * mydims0, self.dims[0])
            mydims0 = d2 - d1
            self.mydims = (mydims0,) + self.dims[1:]
        else:
            self.mydims = ()

        fullshape = self.mydims + self.myshape

        if data is not None:
            if data.shape != fullshape:
                raise ValueError(
                    f'Bad shape for data: {data.shape} != {fullshape}')
            if data.dtype != dtype:
                raise ValueError(
                    f'Bad dtype for data: {data.dtype} != {dtype}')
            if xp is not None:
                # The requested array module must match the kind of data.
                assert (xp is np) == isinstance(
                    data, (np.ndarray, NDArrayReader)), xp
        else:
            data = (xp or np).empty(fullshape, dtype)

        self.data = data
        # The array module is always derived from the actual data.
        if isinstance(data, (np.ndarray, NDArrayReader)):
            xp = np
        else:
            from gpaw.gpu import cupy as cp
            xp = cp
        super().__init__(xp)

        self._matrix: Matrix | None = None  # cache for the matrix property

    def new(self, data=None, dims=None) -> XArray:
        """Create a new array of the same kind (implemented by subclasses).

        Parameters
        ----------
        data:
            Data for the new array.
        dims:
            Leading dimensions for the new array.
        """
        raise NotImplementedError

    def create_work_buffer(self, data_buffer: np.ndarray):
        """Create new Distributed array object of same
        kind, to be used as a buffer array when doing
        sliced operations.

        Parameters
        ----------
        data_buffer:
            Array to use for storage.  It is reinterpreted with the dtype
            of ``self.data``.  As many entries of the first axis as fit
            (the minimum over ``desc.comm``) are placed in it.
        """
        assert isinstance(data_buffer, self.xp.ndarray)
        assert len(self.dims) >= 1
        data_buffer = data_buffer.view(self.data.dtype)
        datasize = data_buffer.size
        X = self.data.shape[1:]
        nX = int(np.prod(X))
        # Choose mybands, s.t. they fit into
        # data_buffer. Hence, datasize divided by nX
        # rounded down.
        if nX == 0:
            mybands = self.data.shape[0]
        else:
            mybands = min(datasize // nX,
                          self.data.shape[0])
        # All domain ranks must agree on the number of bands.
        mybands = self.desc.comm.min_scalar(mybands)
        data = data_buffer[:mybands * nX].reshape(
            (mybands,) + X)
        totalbands = self.comm.sum_scalar(mybands)
        # Dims is (totalbands,) + self.dims[1:], where
        # self.dims[1:] is extra dimensions, such as spin.
        return self.new(data=data,
                        dims=(totalbands,) + self.dims[1:])

    def copy(self):
        """Return a new array holding a copy of the data."""
        return self.new(data=self.data.copy())

    def sanity_check(self) -> None:
        """Sanity check (no-op in the base class)."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __bool__(self):
        # The truth value of an array is ambiguous; refuse to guess.
        raise ValueError('The truth value of an XArray is ambiguous')

    def __len__(self):
        """Global length of the first axis of ``dims``."""
        return self.dims[0]

    def __iter__(self):
        for index in range(self.dims[0]):
            yield self[index]

    def flat(self) -> Iterator[Self]:
        """Iterate over all index combinations of ``dims``.

        Yields ``self`` once if ``dims`` is empty; otherwise yields
        ``self[index]`` for every index tuple in row-major order.
        """
        if self.dims == ():
            yield self
        else:
            for index in np.indices(self.dims).reshape((len(self.dims), -1)).T:
                yield self[tuple(index)]

    def to_xp(self, xp):
        """Return the array on array module *xp* (numpy or cupy).

        Returns ``self`` if already on numpy and *xp* is numpy.
        """
        if xp is self.xp:
            assert xp is np, 'cp -> cp should not be needed!'
            return self
        if xp is np:
            return self.new(data=self.xp.asnumpy(self.data))
        return self.new(data=xp.asarray(self.data))

    @property
    def matrix(self) -> Matrix:
        """2-D :class:`Matrix` wrapping the data (cached).

        Rows correspond to ``dims[0]`` (distributed over ``comm``); all
        remaining axes, including the domain part, are flattened into the
        columns.
        """
        if self._matrix is not None:
            return self._matrix

        nx = prod(self.myshape)
        shape = (self.dims[0], prod(self.dims[1:]) * nx)
        myshape = (self.mydims[0], prod(self.mydims[1:]) * nx)
        dist = (self.comm, -1, 1)
        data = self.data.reshape(myshape)
        self._matrix = Matrix(*shape, data=data, dist=dist)
        return self._matrix

    @trace
    def matrix_elements(self,
                        other: Self,
                        *,
                        out: Matrix | None = None,
                        symmetric: bool | Literal['_default'] = '_default',
                        function=None,
                        domain_sum=True,
                        cc: bool = False) -> Matrix:
        """Calculate matrix elements between the first axes of two arrays.

        The result has shape ``(self.dims[0], other.dims[0])``, with rows
        distributed over ``comm``.  It is built from products over all
        remaining axes, scaled by ``self.dv``.

        Parameters
        ----------
        other:
            Second array.
        out:
            Matrix to write the result into; allocated if not given.
        symmetric:
            Whether the result is Hermitian.  Defaults to
            ``self is other``.  With more than one rank in ``comm``,
            ``symmetric=True`` uses only *self* (*other* is not read).
        function:
            Optional operator that requires ``symmetric``.  On the serial
            path it is applied to *other*.  On the parallel path it is
            passed to the symmetric routine, which applies it to *self*.
        domain_sum:
            Sum the result over ``domain_comm``.
        cc:
            If false (the default), the result is complex conjugated at
            the end.
        """
        if symmetric == '_default':
            symmetric = self is other

        if function:
            # Checked for both the serial and the parallel path, so that a
            # non-symmetric request with an operator is never silently
            # handled differently.
            assert symmetric

        comm = self.comm

        if out is None:
            out = Matrix(self.dims[0], other.dims[0],
                         dist=(comm, -1, 1),
                         dtype=self.desc.dtype,
                         xp=self.xp)

        if comm.size == 1:
            assert other.comm.size == 1
            if function:
                other = function(other)

            M1 = self.matrix
            M2 = other.matrix
            n = M1.shape[0]
            m = M2.shape[0]
            X = M1.shape[1]
            assert M2.shape[1] == X

            # Slice the inner product into blocks, to improve
            # numerical stability. This is especially important
            # for single precision.
            if self.data.dtype in (np.float32, np.complex64):
                # 4096 = 2**12. Largest blocksize that yields
                # good numerical stability. This results in some
                # overhead; however, numerical stability is
                # more important. In the future, we might want
                # to find improvements.
                blocksize = 4096
            else:
                # Double precision (and any other dtype, so that
                # blocksize is always defined): the single-precision
                # blocksize squared, 2**24.  Most likely, the matrix is
                # never sliced into blocks here.
                blocksize = 16777216

            # max(X, 1) makes sure ``out`` is written at least once,
            # even when there are no columns.
            for ind in range(0, max(X, 1), blocksize):
                width = min(blocksize, X - ind)
                m1 = Matrix(n,
                            width,
                            data=M1.data[:, ind:ind + blocksize],
                            xp=self.xp,
                            dist=(comm, -1, 1))
                m2 = Matrix(m,
                            width,
                            data=M2.data[:, ind:ind + blocksize],
                            xp=self.xp,
                            dist=(comm, -1, 1))
                # beta=0 for the first block and 1 afterwards, so later
                # blocks are accumulated into ``out`` (assuming BLAS-style
                # beta semantics).
                m1.multiply(m2, opb='C', alpha=self.dv,
                            symmetric=symmetric, out=out,
                            beta=0 if ind == 0 else 1)

            # functions needs a correction:
            self._matrix_elements_correction(M1, M2, out, symmetric)
        else:
            if symmetric:
                _parallel_me_sym(self, out, function)
            else:
                _parallel_me(self, other, out)

        if not cc:
            out.complex_conjugate()

        if domain_sum:
            self.domain_comm.sum(out.data)

        return out

    def _matrix_elements_correction(self,
                                    M1: Matrix,
                                    M2: Matrix,
                                    out: Matrix,
                                    symmetric: bool) -> None:
        """Hook for PlaneWaveExpansion."""

    def abs_square(self,
                   weights: Array1D,
                   out: UGArray) -> None:
        """Add weighted absolute square of data to output array.

        See also :xkcd:`849`.
        """
        raise NotImplementedError

    def add_ked(self,
                weights: Array1D,
                out: UGArray) -> None:
        """Add weighted absolute square of gradient of data to output array."""
        raise NotImplementedError

    def gather(self, out=None, broadcast=False):
        """Gather the domain part (implemented by subclasses)."""
        raise NotImplementedError

    def gathergather(self):
        """Gather both the domain part and the first axis.

        Returns the fully gathered array, or ``None`` on ranks where either
        gather step returns ``None``.
        """
        a_xX = self.gather()  # gather X (grid-points or plane-waves)
        if a_xX is not None:
            m_xX = a_xX.matrix.gather()  # gather x
            if m_xX is not None:
                data = m_xX.data
                if a_xX.data.dtype != data.dtype:
                    data = data.view(complex)
                return self.desc.new(comm=None).from_data(data)

    def scatter_from(self, data: ArrayND | None = None) -> None:
        """Scatter *data* to all ranks (implemented by subclasses)."""
        raise NotImplementedError

    def redist(self,
               domain,
               comm1: MPIComm, comm2: MPIComm) -> XArray:
        """Redistribute to new domain.

        The "world" is spanned by::

            (self.desc.comm, comm1)

        and::

            (domain.comm, comm2).

        Data is gathered on rank 0 of *comm1*, scattered on rank 0 of
        *comm2* and then broadcast over *comm2*.
        """
        result = domain.empty(self.dims, xp=self.xp)
        if comm1.rank == 0:
            a = self.gather()
        else:
            a = None
        if comm2.rank == 0:
            result.scatter_from(a)
        comm2.broadcast(result.data, 0)
        return result

    def interpolate(self,
                    *,
                    plan1: fftw.FFTPlans | None = None,
                    plan2: fftw.FFTPlans | None = None,
                    grid: UGDesc | None = None,
                    out: UGArray | None = None) -> UGArray:
        raise NotImplementedError

    def integrate(self, other: Self | None = None) -> np.ndarray:
        raise NotImplementedError

    def norm2(self,
              kind: str = 'normal',
              weights: np.ndarray | None = None,
              skip_sum=False) -> np.ndarray:
        raise NotImplementedError

    def trace_inner_product(self, other: Self) -> float:
        raise NotImplementedError
def _parallel_me(psit1_nX: XArray,
                 psit2_nX: XArray,
                 M_nn: Matrix) -> None:
    """Fill ``M_nn`` with matrix elements of two band-distributed arrays.

    Rows of ``M_nn`` are the local bands of *psit1_nX*; columns are all
    bands of *psit2_nX*.  Every rank sends its local block of *psit2_nX*
    directly to each other rank with non-blocking MPI (tag 11).  Two
    alternating buffers let the next block be received while the
    current one is used.  No conjugation or domain sum is applied here
    (``cc=True``, ``domain_sum=False``).
    """
    comm = psit2_nX.comm
    nbands = psit2_nX.dims[0]

    psit1_nX = psit1_nX[:]  # local view

    # Same block distribution as used by XArray.__init__.
    B = (nbands + comm.size - 1) // comm.size
    n_r = [min(r * B, nbands) for r in range(comm.size + 1)]

    xp = psit1_nX.xp
    # Receive buffers must match what is sent (``psit2_nX.data``), which
    # includes any extra axes ``dims[1:]`` after the band axis.
    bufdims = (B,) + psit2_nX.dims[1:]
    buf1_nX = psit1_nX.desc.empty(bufdims, xp=xp)
    buf2_nX = psit1_nX.desc.empty(bufdims, xp=xp)
    psit_nX = psit2_nX  # block used in the current step (own block first)

    for shift in range(comm.size):
        rrequest = None
        srequest = None

        if shift < comm.size - 1:
            # Start communication for the next step.
            srank = (comm.rank + shift + 1) % comm.size
            rrank = (comm.rank - shift - 1) % comm.size
            n1 = n_r[rrank]
            n2 = n_r[rrank + 1]
            mynb = n2 - n1
            if mynb > 0:
                rrequest = comm.receive(buf1_nX.data[:mynb],
                                        rrank, 11, False)
            if psit2_nX.data.size > 0:
                srequest = comm.send(psit2_nX.data, srank, 11, False)

        # Columns belonging to the rank whose block is in psit_nX.
        r2 = (comm.rank - shift) % comm.size
        n1 = n_r[r2]
        n2 = n_r[r2 + 1]
        m_nn = psit1_nX.matrix_elements(psit_nX[:n2 - n1],
                                        cc=True, domain_sum=False)
        M_nn.data[:, n1:n2] = m_nn.data

        if rrequest is not None:
            comm.wait(rrequest)
        if srequest is not None:
            comm.wait(srequest)

        # The freshly received block is used in the next step.
        psit_nX = buf1_nX
        buf1_nX, buf2_nX = buf2_nX, buf1_nX
def _parallel_me_sym(psit1_nX: XArray,
                     M_nn: Matrix,
                     operator: None | Callable[[XArray],
                                               XArray]
                     ) -> None:
    """Fill ``M_nn`` with symmetric matrix elements of a band-distributed
    array with itself.

    If *operator* is given, it is applied once to *psit1_nX*, and the
    resulting local view is used as the left factor.  Only shifts
    ``0 .. comm.size // 2`` are computed in the ring phase.  For an even
    number of ranks, the ranks in the lower half skip the last shift to
    avoid computing the same block pair twice.  In a final exchange
    (tag 12), the first ``(comm.size - 1) // 2`` ranks send conjugate
    transposes of some of their blocks.  Afterwards every rank holds the
    columns of all ranks ``<=`` its own rank, i.e. the lower block
    triangle.  Presumably only that triangle is needed by the caller; the
    ``symmetric=True`` diagonal blocks suggest the same.  No conjugation
    or domain sum is applied here (``cc=True``, ``domain_sum=False``).
    """
    comm = psit1_nX.comm
    nbands = psit1_nX.dims[0]
    B = (nbands + comm.size - 1) // comm.size
    mynbands = psit1_nX.mydims[0]

    # Same block distribution as used by XArray.__init__.
    n_r = [min(r * B, nbands) for r in range(comm.size + 1)]
    mynbands_r = [n_r[r + 1] - n_r[r] for r in range(comm.size)]
    assert mynbands_r[comm.rank] == mynbands

    xp = psit1_nX.xp
    psit2_nX = psit1_nX  # block used in the current step (own block first)
    # Receive buffers must match what is sent (``psit1_nX.data``), which
    # includes any extra axes ``dims[1:]`` after the band axis.
    bufdims = (B,) + psit1_nX.dims[1:]
    buf1_nX = psit1_nX.desc.empty(bufdims, xp=xp)
    buf2_nX = psit1_nX.desc.empty(bufdims, xp=xp)
    half = comm.size // 2

    for shift in range(half + 1):
        rrequest = None
        srequest = None

        if shift < half:
            # Start communication for the next step.
            srank = (comm.rank + shift + 1) % comm.size
            rrank = (comm.rank - shift - 1) % comm.size
            # For an even number of ranks, the last shift pairs rank r
            # with r + half in both directions, so only one direction is
            # communicated.
            skip = comm.size % 2 == 0 and shift == half - 1
            rmynb = mynbands_r[rrank]
            if not (skip and comm.rank < half) and rmynb > 0:
                rrequest = comm.receive(buf1_nX.data[:rmynb],
                                        rrank, 11, False)
            if not (skip and comm.rank >= half) and psit1_nX.data.size > 0:
                srequest = comm.send(psit1_nX.data, srank, 11, False)

        if shift == 0:
            if operator is not None:
                op_psit1_nX = operator(psit1_nX)
            else:
                op_psit1_nX = psit1_nX
            op_psit1_nX = op_psit1_nX[:]  # local view

        if not (comm.size % 2 == 0 and shift == half and comm.rank < half):
            # Columns belonging to the rank whose block is in psit2_nX.
            r2 = (comm.rank - shift) % comm.size
            n1 = n_r[r2]
            n2 = n_r[r2 + 1]
            m_nn = op_psit1_nX.matrix_elements(psit2_nX[:n2 - n1],
                                               symmetric=(shift == 0),
                                               cc=True, domain_sum=False)
            M_nn.data[:, n1:n2] = m_nn.data

        if rrequest is not None:
            comm.wait(rrequest)
        if srequest is not None:
            comm.wait(srequest)

        # The freshly received block is used in the next step.
        psit2_nX = buf1_nX
        buf1_nX, buf2_nX = buf2_nX, buf1_nX

    # Fill in the remaining blocks from their conjugate transposes.
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                n1 = n_r[column]
                n2 = n_r[column + 1]
                if mynbands > 0 and n2 > n1:
                    requests.append(
                        comm.send(M_nn.data[:, n1:n2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                n1 = n_r[row]
                n2 = n_r[row + 1]
                if mynbands > 0 and n2 > n1:
                    block = xp.empty((mynbands, n2 - n1), M_nn.dtype)
                    blocks.append((n1, n2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for n1, n2, block in blocks:
        M_nn.data[:, n1:n2] = block