Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ uv.lock
# IDE
.vscode

# Scratch / temporary scripts
tmp/

# jupyter-book derived files
docs/_autosummary
docs/_build
Expand Down
107 changes: 45 additions & 62 deletions mess/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import jax.numpy as jnp
import jax.numpy.linalg as jnl
import optimistix as optx
from jaxtyping import Array, ScalarLike
from jaxtyping import ScalarLike

from mess.basis import Basis, renorm
from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis
from mess.integrals import kinetic_basis, nuclear_basis, overlap_basis
from mess.interop import to_pyscf
from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf
from mess.orthnorm import symmetric
from mess.structure import nuclear_energy
from mess.two_electron import TwoElectron, ri_from_basis, isdf_thc_ri
from mess.types import FloatNxN, OrthNormTransform
from mess.xcfunctional import (
gga_correlation_lyp,
Expand All @@ -28,6 +29,7 @@

xcstr = Literal["lda", "pbe", "pbe0", "b3lyp", "hfx"]
IntegralBackend = Literal["mess", "pyscf_cart", "pyscf_sph"]
CoulombMethod = Literal["full", "ri", "thc-ri"]


class OneElectron(eqx.Module):
Expand Down Expand Up @@ -61,55 +63,14 @@ def __init__(self, basis: Basis, backend: IntegralBackend = "mess"):
self.nuclear = jnp.array(mol.intor(f"int1e_nuc_{kind}"))


class TwoElectron(eqx.Module):
eri: Array

def __init__(self, basis: Basis, backend: str = "mess"):
"""

Args:
basis (Basis): the basis set used to build the electron repulsion integrals
backend (str, optional): Integral backend used. Defaults to "mess".
"""
super().__init__()
if backend == "mess":
self.eri = eri_basis(basis)
elif backend.startswith("pyscf_"):
mol = to_pyscf(basis.structure, basis.basis_name)
kind = backend.split("_")[1]
self.eri = jnp.array(mol.intor(f"int2e_{kind}", aosym="s1"))

def coloumb(self, P: FloatNxN) -> FloatNxN:
"""Build the Coloumb matrix (classical electrostatic) from the density matrix.

Args:
P (FloatNxN): the density matrix

Returns:
FloatNxN: Coloumb matrix
"""
return jnp.einsum("kl,ijkl->ij", P, self.eri)

def exchange(self, P: FloatNxN) -> FloatNxN:
"""Build the quantum-mechanical exchange matrix from the density matrix

Args:
P (FloatNxN): the density matrix

Returns:
FloatNxN: Exchange matrix
"""
return jnp.einsum("ij,ikjl->kl", P, self.eri)


class HartreeFockExchange(eqx.Module):
two_electron: TwoElectron
two_electron: eqx.Module

def __init__(self, two_electron: TwoElectron):
def __init__(self, two_electron: eqx.Module):
self.two_electron = two_electron

def __call__(self, P: FloatNxN) -> ScalarLike:
K = self.two_electron.exchange(P)
def __call__(self, P: FloatNxN, C_occ=None) -> ScalarLike:
K = self.two_electron.exchange(P, C_occ)
return -0.25 * jnp.sum(P * K)


Expand All @@ -121,7 +82,7 @@ def __init__(self, basis: Basis):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)

def __call__(self, P: FloatNxN) -> ScalarLike:
def __call__(self, P: FloatNxN, C_occ=None) -> ScalarLike:
rho = density(self.basis, self.mesh, P)
eps_xc = lda_exchange(rho) + lda_correlation_vwn(rho)
E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc)
Expand All @@ -136,7 +97,7 @@ def __init__(self, basis: Basis):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)

def __call__(self, P: FloatNxN) -> ScalarLike:
def __call__(self, P: FloatNxN, C_occ=None) -> ScalarLike:
rho, grad_rho = density_and_grad(self.basis, self.mesh, P)
eps_xc = gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho)
E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc)
Expand All @@ -148,40 +109,40 @@ class PBE0(eqx.Module):
mesh: Mesh
hfx: HartreeFockExchange

def __init__(self, basis: Basis, two_electron: TwoElectron):
def __init__(self, basis: Basis, two_electron: eqx.Module):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)
self.hfx = HartreeFockExchange(two_electron)

def __call__(self, P: FloatNxN) -> ScalarLike:
def __call__(self, P: FloatNxN, C_occ=None) -> ScalarLike:
rho, grad_rho = density_and_grad(self.basis, self.mesh, P)
e = 0.75 * gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho)
E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, e)
return E_xc + 0.25 * self.hfx(P)
return E_xc + 0.25 * self.hfx(P, C_occ)


class B3LYP(eqx.Module):
basis: Basis
mesh: Mesh
hfx: HartreeFockExchange

def __init__(self, basis: Basis, two_electron: TwoElectron):
def __init__(self, basis: Basis, two_electron: eqx.Module):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)
self.hfx = HartreeFockExchange(two_electron)

def __call__(self, P: FloatNxN) -> ScalarLike:
def __call__(self, P: FloatNxN, C_occ=None) -> ScalarLike:
rho, grad_rho = density_and_grad(self.basis, self.mesh, P)
eps_x = 0.08 * lda_exchange(rho) + 0.72 * gga_exchange_b88(rho, grad_rho)
vwn_c = (1 - 0.81) * lda_correlation_vwn(rho)
lyp_c = 0.81 * gga_correlation_lyp(rho, grad_rho)
b3lyp_xc = eps_x + vwn_c + lyp_c
E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, b3lyp_xc)
return E_xc + 0.2 * self.hfx(P)
return E_xc + 0.2 * self.hfx(P, C_occ)


def build_xcfunc(
xc_method: xcstr, basis: Basis, two_electron: Optional[TwoElectron] = None
xc_method: xcstr, basis: Basis, two_electron: Optional[eqx.Module] = None
) -> eqx.Module:
if two_electron is None and xc_method in ("pbe0", "b3lyp"):
raise ValueError(
Expand Down Expand Up @@ -211,7 +172,7 @@ class Hamiltonian(eqx.Module):
X: FloatNxN
H_core: FloatNxN
basis: Basis
two_electron: TwoElectron
two_electron: eqx.Module
xcfunc: eqx.Module

def __init__(
Expand All @@ -220,19 +181,37 @@ def __init__(
ont: OrthNormTransform = symmetric,
xc_method: xcstr = "lda",
backend: IntegralBackend = "pyscf_sph",
coulomb: CoulombMethod = "full",
two_electron: Optional[eqx.Module] = None,
):
super().__init__()
self.basis = renorm(basis, backend) if backend != "mess" else basis
one_elec = OneElectron(basis, backend=backend)
S = one_elec.overlap
self.X = ont(S)
self.H_core = one_elec.kinetic + one_elec.nuclear
self.two_electron = TwoElectron(basis, backend=backend)
if two_electron is not None:
self.two_electron = two_electron
else:
match coulomb:
case "full":
self.two_electron = TwoElectron(basis, backend=backend)
case "ri":
self.two_electron = ri_from_basis(basis)
case "thc-ri":
mesh = xcmesh_from_pyscf(basis.structure, level=0)
self.two_electron = isdf_thc_ri(basis, mesh, c_isdf=3.0)
case _:
methods = get_args(CoulombMethod)
raise ValueError(
f"Unknown coulomb method: {coulomb}. "
f"Must be one of: {', '.join(methods)}"
)
self.xcfunc = build_xcfunc(xc_method, self.basis, self.two_electron)

def __call__(self, P: FloatNxN) -> ScalarLike:
def __call__(self, P: FloatNxN, C_occ=None) -> ScalarLike:
E_core = jnp.sum(self.H_core * P)
E_xc = self.xcfunc(P)
E_xc = self.xcfunc(P, C_occ)
J = self.two_electron.coloumb(P)
E_es = 0.5 * jnp.sum(J * P)
E = E_core + E_xc + E_es
Expand Down Expand Up @@ -279,13 +258,17 @@ def minimise(
def f(Z, _):
C = H.orthonormalise(Z)
P = H.basis.density_matrix(C)
return H(P)
n_occ = H.basis.structure.num_electrons // 2
C_occ = C[:, :n_occ]
return H(P, C_occ=C_occ)

solver = optx.BestSoFarMinimiser(solver)
Z = initial_guess_fn(H.basis)
sol = optx.minimise(f, solver, Z, max_steps=max_steps)
C = H.orthonormalise(sol.value)
P = H.basis.density_matrix(C)
E_elec = H(P)
n_occ = H.basis.structure.num_electrons // 2
C_occ = C[:, :n_occ]
E_elec = H(P, C_occ=C_occ)
E_total = E_elec + nuclear_energy(H.basis.structure)
return E_total, C, sol
Loading
Loading