77from pyscf .lib import StreamObject
88from scipy .linalg import fractional_matrix_power
99
10- from .base import VirtualLocalizer
10+ from nbed . localizers . virtual .base import VirtualLocalizer
1111
1212logger = logging .getLogger (__name__ )
1313
@@ -19,7 +19,7 @@ def __init__(
1919 self ,
2020 embedded_scf : StreamObject ,
2121 n_active_atoms : int ,
22- c_loc_occ : tuple ( NDArray , NDArray | None ) ,
22+ c_loc_occ : tuple [ NDArray , NDArray | None ] ,
2323 norm_cutoff : float = 0.05 ,
2424 overlap_cutoff = 1e-5 ,
2525 ):
@@ -38,56 +38,70 @@ def localize_virtual(self) -> StreamObject:
3838
3939 if self .c_loc_occ [1 ] is None :
4040 logger .debug ("Runing PAO for spinless system." )
41- virtuals = _localize__spin_pao (
41+ virtuals = _localize_virtual_spin_pao (
4242 self .c_loc_occ [0 ],
4343 ao_overlap ,
4444 n_act_aos ,
4545 self .norm_cutoff ,
4646 self .overlap_cutoff ,
4747 )
48-
48+ logger . debug ( f" { virtuals . shape = } " )
4949 occ_mo_coeff = self .embedded_scf .mo_coeff [:, self .embedded_scf .mo_occ > 0 ]
50- self .embedded_scf .mo_coeff = np .hstack ((occ_mo_coeff , virtuals ))
51- self .embedded_scf .mo_occ = self .embedded_scf .mo_occ [
52- : self .embedded_scf .mo_coeff .shape [- 1 ]
53- ]
50+ mo_coeff = np .hstack ((occ_mo_coeff , virtuals ))
51+
52+ n_mos = mo_coeff .shape [- 1 ]
53+ n_occ = np .count_nonzero (self .embedded_scf .mo_occ [0 ])
54+ mo_occ = np .array (n_occ * [1 ] + (n_mos - n_occ ) * [0 ])
5455
5556 else : # Restricted open shell
5657 logger .debug ("Running PAO for each spin separately." )
57- alpha_virtuals = _localize__spin_pao (
58+ alpha_virtuals = _localize_virtual_spin_pao (
5859 self .c_loc_occ [0 ],
5960 ao_overlap ,
6061 n_act_aos ,
6162 self .norm_cutoff ,
6263 self .overlap_cutoff ,
6364 )
64- beta_virtuals = _localize__spin_pao (
65+ beta_virtuals = _localize_virtual_spin_pao (
6566 self .c_loc_occ [1 ],
6667 ao_overlap ,
6768 n_act_aos ,
6869 self .norm_cutoff ,
6970 self .overlap_cutoff ,
7071 )
71-
72- occ_mo_coeff = self .embedded_scf .mo_coeff [
73- :, :, self .embedded_scf .mo_occ > 0
72+ logger .debug (f"{ alpha_virtuals .shape = } " )
73+ logger .debug (f"{ beta_virtuals .shape = } " )
74+ alpha_occ_mo_coeff = self .embedded_scf .mo_coeff [0 ][
75+ :, self .embedded_scf .mo_occ [0 ] > 0
76+ ]
77+ beta_occ_mo_coeff = self .embedded_scf .mo_coeff [1 ][
78+ :, self .embedded_scf .mo_occ [1 ] > 0
7479 ]
75- self .embedded_scf .mo_coeff = np .vstack (
80+ mo_coeff = np .array (
81+ [
82+ np .hstack ((alpha_occ_mo_coeff , alpha_virtuals )),
83+ np .hstack ((beta_occ_mo_coeff , beta_virtuals )),
84+ ]
85+ )
86+ n_mos = mo_coeff .shape [- 1 ]
87+ alpha_n_occ = np .count_nonzero (self .embedded_scf .mo_occ [0 ])
88+ beta_n_occ = np .count_nonzero (self .embedded_scf .mo_occ [1 ])
89+ mo_occ = np .vstack (
7690 (
77- np .hstack (( occ_mo_coeff [ 0 ], alpha_virtuals ) ),
78- np .hstack (( occ_mo_coeff [1 ], beta_virtuals ) ),
91+ np .array ( alpha_n_occ * [ 1 ] + ( n_mos - alpha_n_occ ) * [ 0 ] ),
92+ np .array ( beta_n_occ * [1 ] + ( n_mos - beta_n_occ ) * [ 0 ] ),
7993 )
8094 )
81- self . embedded_scf . mo_occ = self . embedded_scf . mo_occ [
82- :, : self .embedded_scf .mo_coeff . shape [ - 1 ]
83- ]
95+
96+ self .embedded_scf .mo_coeff = mo_coeff
97+ self . embedded_scf . mo_occ = mo_occ
8498
8599 return self .embedded_scf
86100
87101 # where should the cutoff values come from?
88102
89103
90- def _localize__spin_pao (
104+ def _localize_virtual_spin_pao (
91105 c_loc_occ : NDArray ,
92106 ao_overlap : NDArray ,
93107 n_act_aos : int ,
@@ -124,6 +138,7 @@ def _localize__spin_pao(
124138
125139 renormalized_paos = s_half @ truncated_paos
126140 logger .debug (f"{ renormalized_paos = } " )
141+ logger .debug (f"{ renormalized_paos .shape = } " )
127142 logger .debug (f"{ np .einsum ("ij,ij->j" , renormalized_paos , renormalized_paos )= } " )
128143 renormalized_paos = renormalized_paos / np .sqrt (
129144 np .einsum ("ij,ij->j" , renormalized_paos , renormalized_paos )
@@ -134,14 +149,22 @@ def _localize__spin_pao(
134149
135150 diagonalized_overlap = renormalized_paos .T @ ao_overlap @ renormalized_paos
136151
152+ logger .debug (f"{ diagonalized_overlap .shape = } " )
137153 logger .debug (f"{ diagonalized_overlap = } " )
138154
139155 eigvals , eigvecs = np .linalg .eigh (diagonalized_overlap )
140156
141- logger .debug (eigvals )
157+ logger .debug (f"Overlap eigenvalues { eigvals } " )
158+ logger .debug (f"{ overlap_cutoff = } " )
142159
160+ logger .debug (f"{ eigvecs .shape = } " )
143161 # How to transform the truncated paos?
144- final_paos = eigvecs [:, eigvals > overlap_cutoff ]
145- logger .debug (final_paos )
146-
162+ final_paos = renormalized_paos [:, eigvals > overlap_cutoff ]
163+ logger .debug (f"{ final_paos = } " )
164+
165+ if final_paos .shape [- 1 ] == 0 :
166+ logger .warning ("No projected atomic orbitals!" )
167+ logger .warning (
168+ "This suggests your active region has no virtual Atomic Orbitals."
169+ )
147170 return final_paos
0 commit comments