Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions dptb/postprocess/elec_struc_cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class ElecStruCal(object):
def __init__ (
self,
model: torch.nn.Module,
device: Union[str, torch.device]=None
device: Union[str, torch.device]=None,
**kwargs
):
'''It initializes ElecStruCal object with a neural network model, optional results path, GUI
usage flag, and device information, and sets up eigenvalues based on model properties.
Expand Down Expand Up @@ -69,12 +70,18 @@ def __init__ (
)
r_max, er_max, oer_max = get_cutoffs_from_model_options(model.model_options)
self.cutoffs = {'r_max': r_max, 'er_max': er_max, 'oer_max': oer_max}

if 'override_overlap' in kwargs and isinstance(kwargs['override_overlap'], str):
self.override_overlap = kwargs['override_overlap']
else:
self.override_overlap = None

def get_data(self,
data: Union[AtomicData, ase.Atoms, str],
pbc:Union[bool,list]=None,
device: Union[str, torch.device]=None,
AtomicData_options:dict=None,
override_overlap:Optional[str]=None):
override_overlap:Union[str,bool,None]=None):
'''The function `get_data` takes input data in the form of a string, ase.Atoms object, or AtomicData
object, processes it accordingly, and returns the AtomicData class.

Expand All @@ -89,14 +96,15 @@ def get_data(self,
device : Union[str, torch.device]
The `device` parameter in the `get_data` function is used to specify the device on which the data
should be processed. If no device is provided, it defaults to `self.device`.
override_overlap : the path for overlap.h5 to use and override overlap matrix from model.
override_overlap : the path for overlap.h5 to use and override overlap matrix from model. If None, will try
to use self.override_overlap; If False, will not try anything.

Returns
-------
the loaded AtomicData object.

'''
if override_overlap is not None:
if override_overlap is not False and override_overlap or self.override_overlap:
if not self.overlap:
self.eigv = Eigenvalues(
idp=self.model.idp,
Expand All @@ -112,7 +120,7 @@ def get_data(self,
device=device if device else self.device,
pbc=pbc,
AtomicData_options=AtomicData_options,
override_overlap=override_overlap
override_overlap=None if override_overlap == False else override_overlap if override_overlap else self.override_overlap
)


Expand All @@ -121,7 +129,7 @@ def get_eigs(self,
klist: np.ndarray,
pbc:Union[bool,list]=None,
AtomicData_options:dict=None,
override_overlap:Optional[str]=None,
override_overlap:Union[str,bool,None]=None,
eig_solver:Optional[str]=None):
'''This function calculates eigenvalues for Hk at specified k-points.

Expand All @@ -142,7 +150,8 @@ def get_eigs(self,
The function `get_eigs` returns the loaded data and the energy eigenvalues as a numpy array.

'''


override_overlap = None if override_overlap == False else override_overlap if override_overlap else self.override_overlap
data = self.get_data(data=data, pbc=pbc, device=self.device,AtomicData_options=AtomicData_options, override_overlap=override_overlap)
# set the kpoint of the AtomicData
data[AtomicDataDict.KPOINT_KEY] = \
Expand Down
9 changes: 7 additions & 2 deletions dptb/postprocess/unified/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_hR(self, atomic_data: dict) -> Tuple[Any, Any]:
pass

@abstractmethod
def get_eigenvalues(self, atomic_data: dict) -> Tuple[dict, torch.Tensor]:
def get_eigenvalues(self, atomic_data: dict, **kwargs) -> Tuple[dict, torch.Tensor]:
"""
Calculate eigenvalues for the given atomic data.

Expand Down Expand Up @@ -225,7 +225,12 @@ def get_hR(self, atomic_data):
def get_eigenvalues(self,
atomic_data: dict,
nk: Optional[int]=None,
solver: Optional[str]=None) -> Tuple[dict, torch.Tensor]:
solver: Optional[str]=None,
**kwargs) -> Tuple[dict, torch.Tensor]:
if not nk and kwargs.get("nk"):
nk = kwargs.get("nk")
if not solver and kwargs.get("eig_solver") or kwargs.get("solver"):
solver = kwargs.get("eig_solver") if kwargs.get("eig_solver") else kwargs.get("solver")
# 1. Get Hamiltonian
atomic_data = self.model_forward(atomic_data)

Expand Down
6 changes: 4 additions & 2 deletions dptb/postprocess/unified/properties/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def set_kpath(self, method: str, **kwargs):
k_tensor = torch.as_tensor(self._k_points, dtype=self._system._calculator.dtype, device=self._system._calculator.device)
self._system._atomic_data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([k_tensor])

def compute(self):
def compute(self, **kwargs):
"""
Compute the band structure using the configured K-path and store result in system.
"""
Expand All @@ -232,7 +232,9 @@ def compute(self):
data = self._system._atomic_data

# Calculate
data, eigs = self._system.calculator.get_eigenvalues(data)
data, eigs = self._system.calculator.get_eigenvalues(data,
nk=kwargs.get("nk", None),
eig_solver=kwargs.get("eig_solver", None))
Comment on lines +235 to +237
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Only forward solver kwargs when they were explicitly requested.

TBSystem still allows duck-typed calculators, so always calling get_eigenvalues(..., nk=None, eig_solver=None) is a compatibility regression for calculators that still implement the old get_eigenvalues(data) signature. Build the kwargs dict conditionally before the call.

🔧 Safer forwarding
-        data, eigs = self._system.calculator.get_eigenvalues(data,
-                                                             nk=kwargs.get("nk", None),
-                                                             eig_solver=kwargs.get("eig_solver", None))
+        calc_kwargs = {}
+        if kwargs.get("nk") is not None:
+            calc_kwargs["nk"] = kwargs["nk"]
+        if kwargs.get("eig_solver") is not None:
+            calc_kwargs["eig_solver"] = kwargs["eig_solver"]
+
+        data, eigs = self._system.calculator.get_eigenvalues(data, **calc_kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@dptb/postprocess/unified/properties/band.py` around lines 235 - 237, The call
to self._system.calculator.get_eigenvalues currently always forwards nk and
eig_solver with None which breaks compatibility with calculators implementing
the old get_eigenvalues(data) signature; instead, build a small kwargs dict
(e.g., eig_kwargs) conditionally: add 'nk' only if 'nk' in kwargs and add
'eig_solver' only if 'eig_solver' in kwargs, then call
self._system.calculator.get_eigenvalues(data, **eig_kwargs). This preserves
compatibility with TBSystem calculators while still forwarding explicitly
provided arguments.


# Extract results
eigenvalues = eigs.detach().cpu().numpy() # [Nk, Nb]
Expand Down
12 changes: 9 additions & 3 deletions dptb/postprocess/unified/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ def get_efermi(self, kmesh: List[int],
device=self.calculator.device)
data = self._atomic_data.copy()
data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([k_tensor])
data, eigs = self.calculator.get_eigenvalues(data)
data, eigs = self.calculator.get_eigenvalues(data,
nk=kwargs.get("nk", None),
eig_solver=kwargs.get("eig_solver", None))
Comment on lines +243 to +245
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Avoid unconditional keyword injection into protocol calculators.

At Line 243, get_eigenvalues is called with nk= and eig_solver= unconditionally (even when both are None). Since Line 46 allows any calculator that merely has get_eigenvalues, this can raise TypeError for valid protocol implementations that don’t accept those keywords.

Proposed compatibility-safe patch
-        data, eigs = self.calculator.get_eigenvalues(data,
-                                                     nk=kwargs.get("nk", None),
-                                                     eig_solver=kwargs.get("eig_solver", None))
+        eigen_kwargs = {}
+        nk = kwargs.get("nk", None)
+        eig_solver = kwargs.get("eig_solver", kwargs.get("solver", None))
+        if nk is not None:
+            eigen_kwargs["nk"] = nk
+        if eig_solver is not None:
+            eigen_kwargs["eig_solver"] = eig_solver
+        data, eigs = self.calculator.get_eigenvalues(data, **eigen_kwargs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
data, eigs = self.calculator.get_eigenvalues(data,
nk=kwargs.get("nk", None),
eig_solver=kwargs.get("eig_solver", None))
eigen_kwargs = {}
nk = kwargs.get("nk", None)
eig_solver = kwargs.get("eig_solver", kwargs.get("solver", None))
if nk is not None:
eigen_kwargs["nk"] = nk
if eig_solver is not None:
eigen_kwargs["eig_solver"] = eig_solver
data, eigs = self.calculator.get_eigenvalues(data, **eigen_kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@dptb/postprocess/unified/system.py` around lines 243 - 245, The call to
self.calculator.get_eigenvalues currently injects nk and eig_solver
unconditionally (using kwargs.get("nk", None) and kwargs.get("eig_solver",
None)), which can cause TypeError for calculator implementations that don't
accept those keywords; change the call to build a filtered kwargs dict (e.g.,
include "nk" and "eig_solver" only if kwargs contains non-None values) and pass
that dict (or explicitly pass only the provided parameters) to
self.calculator.get_eigenvalues so you only supply supported keywords.


calculated_efermi = self.estimate_efermi_e(
eigenvalues=eigs.detach().numpy(),
Expand Down Expand Up @@ -294,7 +296,10 @@ def estimate_efermi_e(self, eigenvalues,
)


def get_bands(self, kpath_config: Optional[dict] = None, reuse: Optional[bool]=True, **kwargs):
def get_bands(self,
kpath_config: Optional[dict] = None,
reuse: Optional[bool]=True,
**kwargs):
# 计算能带,返回 bands
# bands 应该是一个类,也有属性。bands.kpoints, bands.eigenvalues, bands.klabels, bands.kticks, 也有函数 bands.plot()
if self.has_bands and reuse:
Expand All @@ -303,7 +308,8 @@ def get_bands(self, kpath_config: Optional[dict] = None, reuse: Optional[bool]=T
assert kpath_config is not None, "kpath_config must be provided if bands not calculated."
self._bands = BandAccessor(self)
self._bands.set_kpath(**kpath_config)
self._bands.compute()
self._bands.compute(nk=kwargs.get("nk", None),
eig_solver=kwargs.get("eig_solver", None))
self.has_bands = True
return self._bands

Expand Down