Skip to content

Commit 7138261

Browse files
authored
Input Validation with Pydantic (#164)
This PR introduces the NbedConfig pydantic model, with stricter type validation, which is passed to the NbedDriver. Additionally, enums for Projector and Localizer are added to the config objects file.
1 parent 30a21a9 commit 7138261

File tree

15 files changed

+628
-484
lines changed

15 files changed

+628
-484
lines changed

CHANGELOG.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## Unreleased
9+
### Added
10+
- `NbedConfig` pydantic model to validate user input.
11+
812
## [0.0.8]
9-
## Fixed
13+
### Fixed
1014
- `SPADELocalizer` now outputs whole c matrix when virtual localization is stopped early.
1115
- `ACELocalizer` was returning 1 too few molecular orbitals.
1216
- Fixed a bug causing embedded FCI calculations to fail for open shell systems.
1317

14-
## Changed
18+
### Changed
1519
- 'nbed.scf.huzinaga_hf' and 'nbed.scf.huzinaga_rks' combined into 'nbed.scf.huzinaga_scf'
1620
- Combined `scf/huzinaga_` HF and KS methods into `huzinaga_scf`
1721
- python version requirement changed to `>=3.11, <4`
@@ -20,11 +24,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2024
- `localizers` now comprised of `occupied` and `virtual`, with `Localizer` now `OccupiedLocalizer`
2125
- concentric localization moved from `SPADELocalizer` to its own class `ConcentricLocalizer(VirtualLocalizer)`
2226

23-
## Added
27+
### Added
2428
- `.pre-commit-config.yaml` added
2529
- added `ACELocalizer` which implements ace-of-spade method for multiple reaction geometries.
2630

27-
## Removed
31+
### Removed
2832
- `mol_plot.py` removed as not required for/by main uses of package
2933
- dropped support for Pennylane, as they are pinned to numpy <2
3034
- Removed function to convert from fermionic hamiltonian to qubit hamiltonian, which was in `ham_builder.py`.

docs/notebooks/publications/A Scalable Approach to Quantum Simulation via Projection-based Embedding.ipynb

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@
238238
},
239239
{
240240
"cell_type": "code",
241-
"execution_count": 4,
241+
"execution_count": null,
242242
"metadata": {},
243243
"outputs": [],
244244
"source": [
@@ -337,13 +337,13 @@
337337
" transform=transform,\n",
338338
" )\n",
339339
" qham = huz_builder.build()\n",
340-
" result[active][\"huz\"] = {}\n",
341-
" result[active][\"huz\"][\"qham\"] = HamiltonianConverter(qham)._intermediate\n",
342-
" result[active][\"huz\"][\"terms\"] = len(qham.terms)\n",
343-
" result[active][\"huz\"][\"n_qubits\"] = count_qubits(qham)\n",
344-
" result[active][\"huz\"][\"classical_energy\"] = driver._huzinaga[\"classical_energy\"]\n",
345-
" result[active][\"huz\"][\"ground\"] = None\n",
346-
" result[active][\"huz\"][\"e_ccsd\"] = driver._huzinaga[\"e_ccsd\"]\n",
340+
" result[active][\"huzinaga\"] = {}\n",
341+
" result[active][\"huzinaga\"][\"qham\"] = HamiltonianConverter(qham)._intermediate\n",
342+
" result[active][\"huzinaga\"][\"terms\"] = len(qham.terms)\n",
343+
" result[active][\"huzinaga\"][\"n_qubits\"] = count_qubits(qham)\n",
344+
" result[active][\"huzinaga\"][\"classical_energy\"] = driver._huzinaga[\"classical_energy\"]\n",
345+
" result[active][\"huzinaga\"][\"ground\"] = None\n",
346+
" result[active][\"huzinaga\"][\"e_ccsd\"] = driver._huzinaga[\"e_ccsd\"]\n",
347347
" print(\"Huzinaga finished.\")\n",
348348
"\n",
349349
" # untapered_mu = mu_builder.build(taper=False)\n",
@@ -366,7 +366,7 @@
366366
},
367367
{
368368
"cell_type": "code",
369-
"execution_count": 5,
369+
"execution_count": null,
370370
"metadata": {},
371371
"outputs": [],
372372
"source": [
@@ -382,11 +382,11 @@
382382
" embeddings = pd.concat([threes, twos], axis=0)\n",
383383
" full_vals = pd.DataFrame([v for v in df[\"full\"].to_list()], index=df[\"mol_name\"])\n",
384384
" mu_vals = pd.DataFrame([v for v in embeddings[\"mu\"]], index=embeddings.index)\n",
385-
" huz_vals = pd.DataFrame([v for v in embeddings[\"huz\"]], index=embeddings.index)\n",
385+
" huz_vals = pd.DataFrame([v for v in embeddings[\"huzinaga\"]], index=embeddings.index)\n",
386386
"\n",
387387
" energies = pd.concat(\n",
388388
" [df[\"e_dft\"], full_vals[\"e_ccsd\"], mu_vals[\"e_ccsd\"], huz_vals[\"e_ccsd\"]],\n",
389-
" keys=[\"DFT\", \"Full\", \"Mu\", \"Huz\"],\n",
389+
" keys=[\"DFT\", \"Full\", \"Mu\", \"huzinaga\"],\n",
390390
" axis=1,\n",
391391
" )\n",
392392
" energies[\"dft_diffs\"] = (\n",
@@ -396,7 +396,7 @@
396396
" (energies[\"Mu\"] - energies[\"Full\"]) / energies[\"Full\"]\n",
397397
" ).apply(lambda x: np.log10(abs(x)))\n",
398398
" energies[\"huz_diffs\"] = (\n",
399-
" (energies[\"Huz\"] - energies[\"Full\"]) / energies[\"Full\"]\n",
399+
" (energies[\"huzinaga\"] - energies[\"Full\"]) / energies[\"Full\"]\n",
400400
" ).apply(lambda x: np.log10(abs(x)))\n",
401401
" energies = energies.reindex(\n",
402402
" [\n",
@@ -581,7 +581,7 @@
581581
},
582582
{
583583
"cell_type": "code",
584-
"execution_count": 7,
584+
"execution_count": null,
585585
"metadata": {},
586586
"outputs": [],
587587
"source": [
@@ -591,15 +591,15 @@
591591
" print(\"\\nQUBITS\")\n",
592592
" qubits = pd.concat(\n",
593593
" [full_vals[\"n_qubits\"], mu_vals[\"n_qubits\"], huz_vals[\"n_qubits\"]],\n",
594-
" keys=[\"Full\", \"Mu\", \"Huz\"],\n",
594+
" keys=[\"Full\", \"Mu\", \"huzinaga\"],\n",
595595
" axis=1,\n",
596596
" )\n",
597597
" print(qubits)\n",
598598
"\n",
599599
" print(\"\\nTERMS\")\n",
600600
" terms = pd.concat(\n",
601601
" [full_vals[\"terms\"], mu_vals[\"terms\"], huz_vals[\"terms\"]],\n",
602-
" keys=[\"Full\", \"Mu\", \"Huz\"],\n",
602+
" keys=[\"Full\", \"Mu\", \"huzinaga\"],\n",
603603
" axis=1,\n",
604604
" )\n",
605605
" print(terms)\n",
@@ -615,11 +615,11 @@
615615
" print(\"\\nMolecule Results\")\n",
616616
" mol_results = pd.concat(\n",
617617
" [\n",
618-
" energies[\"Full\"] - energies[\"Huz\"],\n",
618+
" energies[\"Full\"] - energies[\"huzinaga\"],\n",
619619
" energies[\"Full\"] - energies[\"Mu\"],\n",
620-
" qubits[\"Huz\"],\n",
620+
" qubits[\"huzinaga\"],\n",
621621
" qubits[\"Mu\"],\n",
622-
" terms[\"Huz\"],\n",
622+
" terms[\"huzinaga\"],\n",
623623
" terms[\"Mu\"],\n",
624624
" ],\n",
625625
" axis=1,\n",
@@ -1122,7 +1122,7 @@
11221122
},
11231123
{
11241124
"cell_type": "code",
1125-
"execution_count": 13,
1125+
"execution_count": null,
11261126
"metadata": {},
11271127
"outputs": [],
11281128
"source": [
@@ -1206,14 +1206,14 @@
12061206
" transform=transform,\n",
12071207
" )\n",
12081208
" qham = huz_builder.build(qubits, taper=False)\n",
1209-
" result[active][\"huz\"] = {}\n",
1210-
" result[active][\"huz\"][\"qham\"] = HamiltonianConverter(qham)._intermediate\n",
1211-
" result[active][\"huz\"][\"terms\"] = len(qham.terms)\n",
1212-
" result[active][\"huz\"][\"n_qubits\"] = count_qubits(qham)\n",
1213-
" result[active][\"huz\"][\"classical_energy\"] = driver._huzinaga[\"classical_energy\"]\n",
1214-
" result[active][\"huz\"][\"ground\"] = None\n",
1215-
" result[active][\"huz\"][\"e_ccsd\"] = driver._huzinaga[\"e_ccsd\"]\n",
1216-
" result[active][\"huz\"][\"nmos\"] = len(driver.localized_system.active_MO_inds)\n",
1209+
" result[active][\"huzinaga\"] = {}\n",
1210+
" result[active][\"huzinaga\"][\"qham\"] = HamiltonianConverter(qham)._intermediate\n",
1211+
" result[active][\"huzinaga\"][\"terms\"] = len(qham.terms)\n",
1212+
" result[active][\"huzinaga\"][\"n_qubits\"] = count_qubits(qham)\n",
1213+
" result[active][\"huzinaga\"][\"classical_energy\"] = driver._huzinaga[\"classical_energy\"]\n",
1214+
" result[active][\"huzinaga\"][\"ground\"] = None\n",
1215+
" result[active][\"huzinaga\"][\"e_ccsd\"] = driver._huzinaga[\"e_ccsd\"]\n",
1216+
" result[active][\"huzinaga\"][\"nmos\"] = len(driver.localized_system.active_MO_inds)\n",
12171217
" print(\"Huzinaga finished.\")\n",
12181218
"\n",
12191219
" # untapered_mu = mu_builder.build(taper=False)\n",
@@ -1236,7 +1236,7 @@
12361236
},
12371237
{
12381238
"cell_type": "code",
1239-
"execution_count": 14,
1239+
"execution_count": null,
12401240
"metadata": {},
12411241
"outputs": [],
12421242
"source": [
@@ -1245,15 +1245,15 @@
12451245
" active_atoms = range(1, 6)\n",
12461246
" mu_qubits = [cyclopentane[str(i)][\"mu\"][\"n_qubits\"] for i in active_atoms]\n",
12471247
" mu_terms = [cyclopentane[str(i)][\"mu\"][\"terms\"] for i in active_atoms]\n",
1248-
" huz_qubits = [cyclopentane[str(i)][\"huz\"][\"n_qubits\"] for i in active_atoms]\n",
1249-
" huz_terms = [cyclopentane[str(i)][\"huz\"][\"terms\"] for i in active_atoms]\n",
1248+
" huz_qubits = [cyclopentane[str(i)][\"huzinaga\"][\"n_qubits\"] for i in active_atoms]\n",
1249+
" huz_terms = [cyclopentane[str(i)][\"huzinaga\"][\"terms\"] for i in active_atoms]\n",
12501250
" full_terms = cyclopentane[\"full\"][\"terms\"]\n",
12511251
" full_n_qubits = cyclopentane[\"full\"][\"n_qubits\"]\n",
12521252
" full_nmos = cyclopentane[\"full\"][\"nmos\"]\n",
12531253
" mu_energies = [cyclopentane[str(i)][\"mu\"][\"e_ccsd\"] for i in active_atoms]\n",
1254-
" huz_energies = [cyclopentane[str(i)][\"huz\"][\"e_ccsd\"] for i in active_atoms]\n",
1254+
" huz_energies = [cyclopentane[str(i)][\"huzinaga\"][\"e_ccsd\"] for i in active_atoms]\n",
12551255
" mu_orbitals = [cyclopentane[str(i)][\"mu\"][\"nmos\"] for i in active_atoms]\n",
1256-
" huz_orbitals = [cyclopentane[str(i)][\"huz\"][\"nmos\"] for i in active_atoms]\n",
1256+
" huz_orbitals = [cyclopentane[str(i)][\"huzinaga\"][\"nmos\"] for i in active_atoms]\n",
12571257
"\n",
12581258
" active_atoms = [0, *active_atoms]\n",
12591259
"\n",
@@ -1944,7 +1944,7 @@
19441944
},
19451945
{
19461946
"cell_type": "code",
1947-
"execution_count": 17,
1947+
"execution_count": null,
19481948
"metadata": {},
19491949
"outputs": [
19501950
{
@@ -2008,8 +2008,8 @@
20082008
" if n_data:\n",
20092009
" if n_data[\"mu\"].get(\"qham\", False):\n",
20102010
" n_data[\"mu\"].pop(\"qham\")\n",
2011-
" if n_data[\"huz\"].get(\"qham\", False):\n",
2012-
" n_data[\"huz\"].pop(\"qham\")"
2011+
" if n_data[\"huzinaga\"].get(\"qham\", False):\n",
2012+
" n_data[\"huzinaga\"].pop(\"qham\")"
20132013
]
20142014
},
20152015
{
@@ -3899,7 +3899,7 @@
38993899
],
39003900
"metadata": {
39013901
"kernelspec": {
3902-
"display_name": "nbed-1_9TTDE1-py3.10",
3902+
"display_name": ".venv",
39033903
"language": "python",
39043904
"name": "python3"
39053905
},
@@ -3913,7 +3913,7 @@
39133913
"name": "python",
39143914
"nbconvert_exporter": "python",
39153915
"pygments_lexer": "ipython3",
3916-
"version": "3.10.11"
3916+
"version": "3.13.1"
39173917
}
39183918
},
39193919
"nbformat": 4,

docs/source/config.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Configuration
2+
-------------
3+
4+
Input data is validated against a Pydantic model, in the `NbedConfig` class. This is then passed to the `NbedDriver`.
5+
6+
.. automodule:: nbed.config
7+
:members:
8+
:undoc-members:
9+
:show-inheritance:

nbed/config.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Custom Types and Enums."""
2+
3+
import os
4+
from enum import Enum
5+
from typing import Annotated, Any
6+
7+
from pydantic import (
8+
BaseModel,
9+
BeforeValidator,
10+
Field,
11+
NonNegativeInt,
12+
PositiveFloat,
13+
PositiveInt,
14+
TypeAdapter,
15+
)
16+
17+
18+
class Projector(Enum):
    """Projector choices accepted by the configuration.

    Members can be looked up from their string value, e.g.
    ``Projector("huzinaga")`` gives ``Projector.HUZ``.
    """

    # The "mu" projector.
    MU = "mu"
    # The Huzinaga projector.
    HUZ = "huzinaga"
    # Select both of the projectors above.
    BOTH = "both"
24+
25+
26+
class Localizer(Enum):
    """Occupied-orbital localizer choices accepted by the configuration.

    Members can be looked up from their string value, e.g.
    ``Localizer("spade")`` gives ``Localizer.SPADE``.
    """

    # SPADE localization.
    SPADE = "spade"
    # Boys localization.
    BOYS = "boys"
    # IBO localization.
    IBO = "ibo"
    # Abbreviated value "pm" (presumably Pipek-Mezey -- confirm).
    PM = "pm"
33+
34+
35+
# Regular expression applied to geometry strings. Read left to right it asks for:
#   1. a first line made only of digits (the atom count in XYZ format),
#   2. a second line that is empty or a single whitespace character,
#   3. zero or more atom lines: one word character followed by three
#      whitespace-separated decimals with exactly one digit before the point.
# There is no trailing "$", so text after the matched prefix is not checked
# by this expression.
_XYZ_GEOMETRY_PATTERN = "^\\d+\n\\s?\n(?:\\w(?:\\s+\\-?\\d\\.\\d+){3}\n?)*"

# String type that pydantic validates against the XYZ pattern above.
XYZGeometry = Annotated[str, Field(pattern=_XYZ_GEOMETRY_PATTERN)]
38+
39+
40+
def validate_xyz_file(maybe_xyz: Any) -> str:
    """Load an XYZ geometry from disk when the input is a path to a file.

    If ``maybe_xyz`` names an existing regular file, the file is read and its
    content is checked against ``XYZGeometry``. Any other input is handed back
    untouched so that later validation can judge it.

    Args:
        maybe_xyz (Any): A path to an existing file, or any other value
            (typically a raw XYZ string).

    Returns:
        str: The file content when ``maybe_xyz`` is a path to a file,
            otherwise ``maybe_xyz`` unchanged.

    Raises:
        pydantic.ValidationError: If the file content does not match the
            ``XYZGeometry`` pattern.
    """
    # Only str / bytes / path-like values can name a file. Without this guard
    # an integer would be treated by ``os.path.exists`` and ``open`` as a file
    # descriptor (0 is stdin), and ``None`` would raise a bare ``TypeError``.
    if not isinstance(maybe_xyz, (str, bytes, os.PathLike)):
        return maybe_xyz

    # ``isfile`` rather than ``exists``: a directory exists but cannot be read
    # as a geometry, and ``open`` on it would raise instead of letting the
    # pattern validation report the problem.
    if not os.path.isfile(maybe_xyz):
        return maybe_xyz

    with open(maybe_xyz) as file:
        content = file.read()
    TypeAdapter(XYZGeometry).validate_strings(content)
    return content
56+
57+
58+
class NbedConfig(BaseModel):
    """Config for Nbed.

    Args:
        geometry (str): Path to .xyz file containing molecular geometry or raw xyz string.
        n_active_atoms (int): The number of atoms to include in the active region.
        basis (str): The name of an atomic orbital basis set to use for chemistry calculations.
        xc_functional (str): The name of an Exchange-Correlation functional to be used for DFT.
        projector (str): Projector to screen out environment orbitals. One of 'mu', 'huzinaga' or 'both'.
        localization (str): Orbital localization method to use. One of 'spade', 'boys', 'ibo' or 'pm'.
        convergence (float): The convergence tolerance for energy calculations.
        charge (int): Charge of molecular species. May be negative (anions).
        spin (int): Spin of molecular species, non-negative (presumably the PySCF convention -- confirm).
        unit (str): molecular geometry unit 'Angstrom' or 'Bohr'
        symmetry (bool): Symmetry flag (presumably forwarded to PySCF -- confirm).
        mu_level_shift (float): Level shift parameter to use for mu-projector.
        run_ccsd_emb (bool): Whether or not to find the CCSD energy of embedded system for reference.
        run_fci_emb (bool): Whether or not to find the FCI energy of embedded system for reference.
        run_virtual_localization (bool): Whether or not to localize virtual orbitals.
        run_dft_in_dft (bool): Whether or not to run a DFT-in-DFT calculation (presumably for reference -- confirm).
        n_mo_overwrite (tuple[None| int, None | int]): Optional overwrite values for occupied localizers.
        max_ram_memory (int): Amount of RAM memory in MB available for PySCF calculation
        occupied_threshold (float): Threshold strictly between 0 and 1 (presumably for occupied localization -- confirm).
        virtual_threshold (float): Threshold strictly between 0 and 1 (presumably for virtual localization -- confirm).
        max_shells (int): Maximum number of shells (presumably for concentric virtual localization -- confirm).
        init_huzinaga_rhf_with_mu (bool): Hidden flag to seed huzinaga RHF with mu shift result (for developers only)
        max_hf_cycles (int): max number of Hartree-Fock iterations allowed (for global and local HFock)
        max_dft_cycles (int): max number of DFT iterations allowed in scf calc
        force_unrestricted (bool): Presumably forces unrestricted SCF calculations -- confirm.
        mm_coords (list | None): Optional list (presumably coordinates of MM point charges -- confirm).
        mm_charges (list | None): Optional list (presumably MM point charge values -- confirm).
        mm_radii (list | None): Optional list (presumably MM point charge radii -- confirm).
    """

    # A path is resolved to file content first, then the string is pattern-checked.
    geometry: Annotated[XYZGeometry, BeforeValidator(validate_xyz_file)]
    n_active_atoms: PositiveInt
    basis: str
    xc_functional: str
    projector: Projector = Field(default=Projector.MU)
    localization: Localizer = Field(default=Localizer.SPADE)
    convergence: PositiveFloat = 1e-6
    # Plain int: a molecular charge can be negative, so anions must validate.
    charge: int = Field(default=0)
    spin: NonNegativeInt = Field(default=0)
    unit: str = "angstrom"
    symmetry: bool = False
    mu_level_shift: PositiveFloat = 1e6
    run_ccsd_emb: bool = False
    run_fci_emb: bool = False
    run_virtual_localization: bool = True
    run_dft_in_dft: bool = False
    n_mo_overwrite: tuple[None | NonNegativeInt, None | NonNegativeInt] = (None, None)
    max_ram_memory: PositiveInt = 4000
    occupied_threshold: float = Field(default=0.95, gt=0, lt=1)
    virtual_threshold: float = Field(default=0.95, gt=0, lt=1)
    max_shells: PositiveInt = 4
    init_huzinaga_rhf_with_mu: bool = False
    max_hf_cycles: PositiveInt = Field(default=50)
    max_dft_cycles: PositiveInt = Field(default=50)
    force_unrestricted: bool = False
    mm_coords: list | None = None
    mm_charges: list | None = None
    mm_radii: list | None = None

0 commit comments

Comments
 (0)