Skip to content

Commit d8c54ba

Browse files
authored
Allow overwriting arguments in interface (#167)
# Allow overwriting arguments in nbed function. Where keyword arguments exist these are used to overwrite existing values, all inputs are then revalidated. Doing this makes it much easier to quickly try options or programatically update inputs to `nbed()`, while ensuring that the NbedDriver class stays simple. # Seperates out embedded-fci and dft-in-dft functions from driver class for ease of testing. Added in some tests for dft-in-dft, ensuring that the energy matches whole ssytem dft where the same exchange correlation functional is used.
1 parent fb4558e commit d8c54ba

File tree

8 files changed

+251
-158
lines changed

8 files changed

+251
-158
lines changed

nbed/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Init for Nbed package."""
22

3+
from .config import NbedConfig
34
from .embed import nbed
45
from .utils import setup_logs
56

6-
__all__ = ["nbed"]
7+
__all__ = ["nbed", "NbedConfig"]
78

89
setup_logs()

nbed/driver.py

Lines changed: 204 additions & 147 deletions
Large diffs are not rendered by default.

nbed/embed.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,29 @@
1313
logger = logging.getLogger(__name__)
1414

1515

16+
def overwrite_config_kwargs(config: NbedConfig, **config_kwargs) -> NbedConfig:
17+
"""Overwrites config values with key-words and revalidates.
18+
19+
Args:
20+
config (NbedConfig): A config model.
21+
config_kwargs (dict): Any possible key-word arguments.
22+
23+
Returns:
24+
NbedConfig: A validated config model.
25+
26+
Raises:
27+
ValidationError: If key-word arguments provided are not part of model.
28+
"""
29+
if config_kwargs != {}:
30+
logger.info("Overwriting select field with additonal config.")
31+
config_dict = config.model_dump()
32+
for k, v in config_kwargs.items():
33+
config_dict[k] = v
34+
return NbedConfig(**config_dict)
35+
else:
36+
return config
37+
38+
1639
def nbed(
1740
config: NbedConfig | str | None = None,
1841
**config_kwargs,
@@ -29,15 +52,20 @@ def nbed(
2952
Returns:
3053
NbedDriver: An embedded driver.
3154
"""
55+
logger.info(f"Running Nbed with:\n\tconfig\t{config}\n\tkeywords\t{config_kwargs}")
3256
match config:
3357
case NbedConfig():
3458
logger.info("Using validated config.")
59+
config = overwrite_config_kwargs(config, **config_kwargs)
60+
3561
case str() | Path():
3662
logger.info("Using config file %s", config)
63+
logger.info("Validating config from file.")
3764
with open(FilePath(config)) as f:
38-
logger.info("Validating config from file.")
3965
data = json.load(f)
40-
config = NbedConfig(**data)
66+
config = NbedConfig(**data)
67+
config = overwrite_config_kwargs(config, **config_kwargs)
68+
4169
case None:
4270
logger.info("Validating config from passed arguments.")
4371
logger.debug(f"{config_kwargs=}")

nbed/ham_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def reduce_virtuals(scf_method, n_frozen_virt: int) -> lib.StreamObject:
278278
reduced_scf_method.mo_coeff = reduced_scf_method.mo_coeff[:, :, :-n_frozen_virt]
279279
reduced_scf_method.mo_occ = reduced_scf_method.mo_occ[:, :-n_frozen_virt]
280280

281-
elif isinstance(reduced_scf_method, (scf.hf.RHF, scf.rohf.ROHF)):
281+
elif isinstance(reduced_scf_method, (scf.hf.RHF)):
282282
reduced_scf_method.mo_coeff = reduced_scf_method.mo_coeff[:, :-n_frozen_virt]
283283
reduced_scf_method.mo_occ = reduced_scf_method.mo_occ[:-n_frozen_virt]
284284

nbed/localizers/ace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def localize_path(self) -> tuple[int, int]:
7070
logger.debug("Singular Values")
7171
logger.debug(singular_values)
7272

73-
if isinstance(scf_object, (scf.hf.RHF, dft.rks.RKS)):
73+
if isinstance(scf_object, (scf.rhf.RHF, dft.rks.RKS)):
7474
alpha = self.localize_spin([s[0] for s in singular_values])
7575
beta = alpha
7676
elif isinstance(scf_object, (scf.uhf.UHF, dft.uks.UKS)):

nbed/scf/huzinaga_scf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def huzinaga_scf(
169169
scf_energy = calculate_ks_energy(
170170
scf_method, dft_potential, density_matrix, huzinaga_op_std
171171
)
172-
elif isinstance(scf_method, (scf.hf.RHF, scf.uhf.UHF)):
172+
elif isinstance(scf_method, (scf.rhf.RHF, scf.uhf.UHF)):
173173
hamiltonian = (
174174
scf_method.get_hcore() + dft_potential + 0.5 * vhf + huzinaga_op_std
175175
)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def water_mol(water_filepath) -> gto.Mole:
3434

3535
@pytest.fixture(scope="module")
3636
def water_rhf(water_molecule) -> StreamObject:
37-
rhf = scf.RHF(water_molecule)
37+
rhf = scf.rhf.RHF(water_molecule)
3838
rhf.kernel()
3939
return rhf
4040

tests/test_driver.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,25 @@ def huz_unrestricted_driver(nbed_config) -> NbedDriver:
4343
driver.embed()
4444
return driver
4545

46+
def test_restricted_dft_in_dft(mu_driver, huz_driver):
47+
mu_did = mu_driver._dft_in_dft(ProjectorEnum.MU)
48+
huz_did = huz_driver._dft_in_dft(ProjectorEnum.HUZ)
49+
assert np.isclose(mu_did["e_dft_in_dft"], mu_driver._global_ks().e_tot)
50+
assert np.isclose(huz_did["e_dft_in_dft"], huz_driver._global_ks().e_tot)
51+
assert np.isclose(mu_did["e_dft_in_dft"], huz_did["e_dft_in_dft"])
52+
4653
def test_embedded_fci(nbed_config, mu_driver, mu_unrestricted_driver, huz_driver, huz_unrestricted_driver):
47-
assert(np.isclose(mu_driver._run_emb_FCI(mu_driver.embedded_scf).e_tot, -62.261794716560416))
48-
assert(np.isclose(mu_unrestricted_driver._run_emb_FCI(mu_unrestricted_driver.embedded_scf).e_tot, -62.261794716560416))
54+
assert(np.isclose(mu_driver._run_emb_fci(mu_driver.embedded_scf).e_tot, -62.261794716560416))
55+
assert(np.isclose(mu_unrestricted_driver._run_emb_fci(mu_unrestricted_driver.embedded_scf).e_tot, -62.261794716560416))
4956
nbed_config.projector = ProjectorEnum.HUZ
5057
nbed_config.n_active_atoms=1
5158
huz_driver = NbedDriver(nbed_config)
5259
huz_driver.embed()
53-
assert(np.isclose(huz_driver._run_emb_FCI(huz_driver.embedded_scf).e_tot, -51.61379094995273))
60+
assert(np.isclose(huz_driver._run_emb_fci(huz_driver.embedded_scf).e_tot, -51.61379094995273))
5461
nbed_config.force_unrestricted = True
5562
huz_unrestricted_driver = NbedDriver(nbed_config)
5663
huz_unrestricted_driver.embed()
57-
assert(np.isclose(huz_unrestricted_driver._run_emb_FCI(huz_unrestricted_driver.embedded_scf).e_tot, -51.61379094995273))
64+
assert(np.isclose(huz_unrestricted_driver._run_emb_fci(huz_unrestricted_driver.embedded_scf).e_tot, -51.61379094995273))
5865

5966
def test_restricted_projector_results_match(mu_driver, huz_driver) -> None:
6067
assert mu_driver._mu is not {} and mu_driver._huzinaga is None

0 commit comments

Comments
 (0)