Source code for tad_libcint.interface.namemanager

# This file is part of tad-libcint, modified from diffqc/dqc.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Original file licensed under the Apache License, Version 2.0 by diffqc/dqc.
# Modifications made by Grimme Group.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interface: Integral Namemanager
===============================

The libcint interface is accessed via strings. This module provides the
corresponding name handling and manipulation.
"""

from __future__ import annotations

import copy
import re
from collections import defaultdict

from ..typing import Sequence
from .symmetry import s1

__all__ = ["IntorNameManager"]


[docs] class IntorNameManager: """ Class for integral name manipulation. This class should only perform string-manipulation and no array operations. """ # ops name must not contain sep name ops_name = ["ip", "rr"] # name of basis operators sep_name = ["a", "b"] # separator of basis (other than the middle operator) # components shape of raw operator and basis operators # should be a tuple with AT MOST 1 element rawop_comp = defaultdict( tuple, { # type: ignore "r0": (3,), "r0r0": (9,), "r0r0r0": (27,), "j": (3,), "jj": (9,), "jjj": (27,), "m": (3,), "mm": (9,), "mmm": (27,), "n": (3,), "nn": (9,), "nnn": (27,), }, ) op_comp = defaultdict( tuple, { # type: ignore "ip": (3,), }, ) # the number of new dimensions added with the operators rawop_ndim = defaultdict(int, {k: len(v) for (k, v) in rawop_comp.items()}) op_ndim = defaultdict(int, {k: len(v) for (k, v) in op_comp.items()}) def __init__(self, int_type: str, shortname: str): self._int_type = int_type self._shortname = shortname self._rawop, self._ops = self.split_name(int_type, shortname) self._nbasis = len(self._ops) # middle index (where the rawops should be) self._imid = (self._nbasis + 1) // 2 @property def fullname(self): return self._int_type + "_" + self._shortname @property def rawopname(self): return self._rawop @property def int_type(self): return self._int_type @property def shortname(self): return self._shortname @property def order(self) -> int: """ Get the order of the derivative of the integral. Returns ------- int Order of derivative. """ derivative_order = 0 for ops in self._ops: for op in ops: if op == "ip": derivative_order += 1 elif op == "ipip": derivative_order += 2 elif op == "ipipip": derivative_order += 3 return derivative_order
[docs] def get_intgl_name(self, spherical: bool) -> str: """ Get the full name of the integral in libcint library. Parameters ---------- spherical : bool Whether the integral is in spherical or cartesian coordinates. Returns ------- str Full name of the integral in libcint library. """ cartsph = "sph" if spherical else "cart" return f"{self.fullname}_{cartsph}"
[docs] def get_ft_intgl_name(self, spherical: bool) -> str: """ Get the full name of the Fourier transform integral in libcint library. Parameters ---------- spherical : bool Whether the integral is in spherical or cartesian coordinates. Returns ------- str Full name of the Fourier transform integral in libcint library. Raises ------ NotImplementedError If the Fourier transform integral is not implemented for the given integral type. """ cartsph = "sph" if spherical is True else "cart" int_type = self._int_type if int_type == "int1e": return f"GTO_ft_{self._shortname}_{cartsph}" raise NotImplementedError( f"FT integral for {int_type} not implemented." )
[docs] def get_intgl_deriv_namemgr( self, derivop: str, ibasis: int ) -> IntorNameManager: """ Get the name manager of a new integral when derivop is applied to ibasis-th basis. Parameters ---------- derivop : str String of the derivative operation. ibasis : int Which basis the derivative operation should be performed (0-based). Returns ------- IntorNameManager Name manager of the new integral. """ assert derivop in self.ops_name assert ibasis < self._nbasis ops = copy.copy(self._ops) ops[ibasis] = [derivop] + ops[ibasis] sname = self.join_name(self._int_type, self._rawop, ops) return IntorNameManager(self._int_type, sname)
[docs] def get_intgl_deriv_newaxispos( self, derivop: str, ibasis: int ) -> None | int: """ Get the new axis position in the new integral name when derivop is applied Parameters ---------- derivop : str String of the derivative operation. ibasis : int Which basis the derivative operation should be performed (0-based). Returns ------- None | int New axis position or None if no new axis is inserted. """ # get how many new axes the operator is going to add op_ndim = self.op_ndim[derivop] if op_ndim == 0: return None ops_flat: list[str] = sum(self._ops[:ibasis], []) new_ndim = sum(self.op_ndim[op] for op in ops_flat) # check if rawsname should also be included include_rname = ibasis >= self._imid if include_rname: new_ndim += self.rawop_ndim[self._rawop] return new_ndim
[docs] def get_intgl_components_shape(self) -> tuple[int, ...]: # returns the component shape of the array of the given integral ops_flat_l: list[str] = sum(self._ops[: self._imid], []) ops_flat_r: list[str] = sum(self._ops[self._imid :], []) comp_shape = ( sum([self.op_comp[op] for op in ops_flat_l], ()) + self.rawop_comp[self._rawop] + sum([self.op_comp[op] for op in ops_flat_r], ()) ) return comp_shape # type: ignore
[docs] def get_intgl_symmetry(self, _: Sequence[int]) -> s1.S1Symmetry: return s1.S1Symmetry()
[docs] def get_transpose_path_to( self, other: IntorNameManager ) -> list[tuple[int, int]] | None: """ Get the transpose path to the other integral. Check if the other integral can be achieved by transposing the current integral. Parameters ---------- other : IntorNameManager The other integral name manager. Returns ------- list[tuple[int, int]] | None Transpose path of `self` to get the same result as the `other` integral or `None` if it cannot be achieved. Raises ------ RuntimeError If the number of basis is not supported. """ nbasis = self._nbasis # get the basis transpose paths if nbasis == 2: transpose_paths: list[list[tuple[int, int]]] = [ [], [(-1, -2)], ] elif nbasis == 3: # NOTE: the third basis is usually an auxiliary basis which # typically different from the first two transpose_paths = [ [], [(-2, -3)], ] elif nbasis == 4: transpose_paths = [ [], [(-3, -4)], [(-1, -2)], [(-1, -3), (-2, -4)], [(-1, -3), (-2, -4), (-2, -1)], [(-1, -3), (-2, -4), (-3, -4)], ] else: raise self._nbasis_error(nbasis) def _swap( p: list[list[str]], path: list[tuple[int, int]] ) -> list[list[str]]: # swap the pattern according to the given transpose path r = p[:] # make a copy for i0, i1 in path: r[i0], r[i1] = r[i1], r[i0] return r # try all the transpose path until gets a match for transpose_path in transpose_paths: # pylint: disable=protected-access if _swap(self._ops, transpose_path) == other._ops: return transpose_path return None
[docs] def get_comp_permute_path( self, transpose_path: list[tuple[int, int]] ) -> list[int]: """ Get the component permute path given the basis transpose path. Parameters ---------- transpose_path : list[tuple[int, int]] Transpose path of the basis. Returns ------- list[int] Component permute path. """ # flat_ops: list[str] = sum(self._ops, []) # n_ip = flat_ops.count("ip") # get the positions of the axes dim_pos = [] ioffset = 0 for i, ops in enumerate(self._ops): if i == self._imid: naxes = self.rawop_ndim[self._rawop] dim_pos.append(list(range(ioffset, ioffset + naxes))) ioffset += naxes naxes = sum([self.op_ndim[op] for op in ops]) dim_pos.append(list(range(ioffset, ioffset + naxes))) ioffset += naxes # add the bases' axes (assuming each basis only occupy one axes) for i in range(self._nbasis): dim_pos.append([ioffset]) ioffset += 1 # swap the axes for t0, t1 in transpose_path: s0 = t0 + self._nbasis s1 = t1 + self._nbasis s0 += 1 if s0 >= self._imid else 0 s1 += 1 if s1 >= self._imid else 0 dim_pos[s0], dim_pos[s1] = dim_pos[s1], dim_pos[s0] # flatten the list to get the permutation path dim_pos_flat: list[int] = sum(dim_pos, []) return dim_pos_flat
[docs] @classmethod def split_name( cls, int_type: str, shortname: str ) -> tuple[str, list[list[str]]]: """ Split the shortname into operator per basis. Parameters ---------- int_type : str Type of the integral. shortname : str Shortname of the integral. Returns ------- tuple[str, list[list[str]]] Raw shortname (i.e., the middle operator) and list of basis-operator shortname. Raises ------ RuntimeError If the number of basis is not supported. """ deriv_ops = cls.ops_name deriv_pattern = re.compile("(" + ("|".join(deriv_ops)) + ")") # get the raw shortname (i.e. shortname without derivative operators) rawsname = shortname for op in deriv_ops: rawsname = rawsname.replace(op, "") nbasis = cls.get_nbasis(int_type) if nbasis == 2: ops_str = shortname.split(rawsname) elif nbasis == 3: assert rawsname.startswith("a"), rawsname rawsname = rawsname[1:] ops_l, ops_r = shortname.split(rawsname) ops_l1, ops_l2 = ops_l.split("a") ops_str = [ops_l1, ops_l2, ops_r] elif nbasis == 4: assert rawsname.startswith("a") and rawsname.endswith("b"), rawsname rawsname = rawsname[1:-1] ops_l, ops_r = shortname.split(rawsname) ops_l1, ops_l2 = ops_l.split("a") ops_r1, ops_r2 = ops_r.split("b") ops_str = [ops_l1, ops_l2, ops_r1, ops_r2] else: raise cls._nbasis_error(nbasis) ops = [re.findall(deriv_pattern, op_str) for op_str in ops_str] assert len(ops) == nbasis return rawsname, ops
[docs] @classmethod def join_name( cls, int_type: str, rawsname: str, ops: list[list[str]] ) -> str: """ Join the raw shortname and list of basis operators into a shortname. Parameters ---------- int_type : str Type of the integral. rawsname : str Raw shortname (i.e., the middle operator). ops : list[list[str]] List of basis-operator shortname. Returns ------- str Shortname of the integral. Raises ------ RuntimeError If the number of basis is not supported. """ nbasis = cls.get_nbasis(int_type) ops_str = ["".join(op) for op in ops] assert len(ops_str) == nbasis if nbasis == 2: return ops_str[0] + rawsname + ops_str[1] if nbasis == 3: return ( ops_str[0] + cls.sep_name[0] + ops_str[1] + rawsname + ops_str[2] ) if nbasis == 4: return ( ops_str[0] + cls.sep_name[0] + ops_str[1] + rawsname + ops_str[2] + cls.sep_name[1] + ops_str[3] ) raise cls._nbasis_error(nbasis)
[docs] @classmethod def get_nbasis(cls, int_type: str) -> int: """ Get the number of basis for the given integral type. Parameters ---------- int_type : str Type of the integral. Returns ------- int Number of basis. Raises ------ RuntimeError If the integral type is unknown. """ if int_type in ("int1e", "int2c2e"): return 2 if int_type == "int3c2e": return 3 if int_type == "int2e": return 4 raise RuntimeError(f"Unknown integral type: {int_type}")
@classmethod def _nbasis_error(cls, nbasis: int): return RuntimeError(f"Unknown integral with {nbasis} basis") def __str__(self): return ( f"{self.__class__.__name__}" f"(int_type={self._int_type!r}, shortname={self._shortname!r})" ) def __repr__(self) -> str: return str(self)