Skip to content

Commit 74a8649

Browse files
Merge pull request #27 from AsymmetryChou/rgf_acc
Implement batched RGF refactorization and related unit tests
2 parents 06278bc + 79e852d commit 74a8649

8 files changed

Lines changed: 815 additions & 279 deletions

File tree

dpnegf/negf/device_property.py

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@
1616
"""
1717
log = logging.getLogger(__name__)
1818

19+
20+
def _build_s_in_batched(hd, seinL, seinR, idx0, idy0, idx1, idy1):
21+
'''Allocate per-block [B, n_q, n_q] zeros and inject the corner self-energy slices.
22+
23+
Mirrors the scalar construction in cal_green_function (the seinL/seinR fed in
24+
here are the already-scaled 1j*(seL - seL.mH) * f tensors of shape [B,n,n]).
25+
'''
26+
B = seinL.shape[0]
27+
s_in = [torch.zeros((B,) + tuple(blk.shape), dtype=torch.complex128) for blk in hd]
28+
s_in[0][:, :idx0, :idy0] = s_in[0][:, :idx0, :idy0] + seinL[:, :idx0, :idy0]
29+
s_in[-1][:, -idx1:, -idy1:] = s_in[-1][:, -idx1:, -idy1:] + seinR[:, -idx1:, -idy1:]
30+
return s_in
31+
32+
1933
class DeviceProperty(object):
2034
'''Device object for NEGF calculation
2135
@@ -150,8 +164,12 @@ def cal_green_function(self, energy, kpoint, eta_device=0., block_tridiagonal=Tr
150164
A boolean parameter that indicates whether the last column blocks of the retarded Green's function are needed.
151165
'''
152166
assert len(np.array(kpoint).reshape(-1)) == 3
153-
if not isinstance(energy, torch.Tensor):
154-
energy = torch.tensor(energy, dtype=torch.complex128)
167+
energy = torch.as_tensor(energy, dtype=torch.complex128)
168+
if energy.ndim == 0:
169+
energy = energy.reshape(1)
170+
assert energy.ndim == 1, f"energy must be 0-d, scalar, or 1-D [B]; got shape {tuple(energy.shape)}"
171+
B = energy.shape[0]
172+
batched_mode = B > 1
155173

156174
self.block_tridiagonal = block_tridiagonal
157175
if self.kpoint is None or abs(self.kpoint - torch.tensor(kpoint)).sum() > 1e-5:
@@ -200,33 +218,45 @@ def cal_green_function(self, energy, kpoint, eta_device=0., block_tridiagonal=Tr
200218

201219
seL = self.lead_L.se
202220
seR = self.lead_R.se
221+
if batched_mode:
222+
assert seL.ndim == 3 and seR.ndim == 3, f"In batched mode, the self-energy should have shape [B,n,n], but got {seL.shape} and {seR.shape}"
223+
else:
224+
assert seL.ndim == 2 and seR.ndim == 2, f"In non-batched mode, the self-energy should have shape [n,n], but got {seL.shape} and {seR.shape}"
225+
203226
s01, s02 = self.hd[0].shape # The shape of the first H block
204-
se01, se02 = seL.shape # The shape of the left self-energy
227+
se01, se02 = seL.shape[-2], seL.shape[-1] # last two dims work for [n,n] and [B,n,n]
205228
s11, s12 = self.hd[-1].shape
206-
se11, se12 = seR.shape
229+
se11, se12 = seR.shape[-2], seR.shape[-1]
207230
idx0, idy0 = min(s01, se01), min(s02, se02)
208231
idx1, idy1 = min(s11, se11), min(s12, se12)
209232
if block_tridiagonal:
210233
# Based on the block tridiagonal algorithm, the shape of the self-energy should be
211-
# equal to or larger than the corresponding Hamiltonian block
234+
# equal to or lesser than the corresponding Hamiltonian block
212235
if se01 > s01 or se02 > s02:
213236
log.warning(f"The shape of left self-energy ({se01},{se02}) is larger than\
214237
the first Hamiltonian block ({s01},{s02}).")
215238
raise ValueError("Left Lead Self Energy size is larger than the first Hamiltonian Block.")
216239
if se11 > s11 or se12 > s12:
217-
log.warning(f"The shape of right self-energy ({se11},{se12}) is different from\
240+
log.warning(f"The shape of right self-energy ({se11},{se12}) is larger than\
218241
the last Hamiltonian block ({s11},{s12}).")
219242
raise ValueError("Right Lead Self Energy size is larger than the last Hamiltonian Block.")
220243

221244
green_funcs = {}
222245

223246
if need_lesser:
224247
# Fluctuation-Dissipation theorem; only build s_in when the lesser GF is consumed
225-
seinL = 1j*(seL-seL.conj().T) * self.lead_L.fermi_dirac(energy+self.E_ref).reshape(-1)
226-
seinR = 1j*(seR-seR.conj().T) * self.lead_R.fermi_dirac(energy+self.E_ref).reshape(-1)
227-
s_in = [torch.zeros(i.shape).cdouble() for i in self.hd]
228-
s_in[0][:idx0,:idy0] = s_in[0][:idx0,:idy0] + seinL[:idx0,:idy0]
229-
s_in[-1][-idx1:,-idy1:] = s_in[-1][-idx1:,-idy1:] + seinR[-idx1:,-idy1:]
248+
if batched_mode:
249+
fL = self.lead_L.fermi_dirac(energy + self.E_ref).reshape(B, 1, 1)
250+
fR = self.lead_R.fermi_dirac(energy + self.E_ref).reshape(B, 1, 1)
251+
seinL = 1j * (seL - seL.mH) * fL
252+
seinR = 1j * (seR - seR.mH) * fR
253+
s_in = _build_s_in_batched(self.hd, seinL, seinR, idx0, idy0, idx1, idy1)
254+
else:
255+
seinL = 1j*(seL-seL.conj().T) * self.lead_L.fermi_dirac(energy+self.E_ref).reshape(-1)
256+
seinR = 1j*(seR-seR.conj().T) * self.lead_R.fermi_dirac(energy+self.E_ref).reshape(-1)
257+
s_in = [torch.zeros(i.shape).cdouble() for i in self.hd]
258+
s_in[0][:idx0,:idy0] = s_in[0][:idx0,:idy0] + seinL[:idx0,:idy0]
259+
s_in[-1][-idx1:,-idy1:] = s_in[-1][-idx1:,-idy1:] + seinR[-idx1:,-idy1:]
230260
else:
231261
s_in = 0
232262

@@ -322,26 +352,36 @@ def _cal_current_nscf_(self, energy_grid, tc):
322352

323353

324354
def _cal_tc_(self):
325-
'''calculate the transmission coefficient
326-
355+
'''calculate the transmission coefficient
356+
327357
Returns
328358
-------
329-
tc is the transmission coefficient
330-
359+
tc is the transmission coefficient
360+
331361
'''
332362

333-
tx, ty = self.g_trans.shape
334-
lx, ly = self.lead_L.se.shape
335-
rx, ry = self.lead_R.se.shape
363+
g_trans = self.g_trans
364+
batched = g_trans.ndim == 3
365+
tx, ty = g_trans.shape[-2], g_trans.shape[-1]
366+
gammaL_full = self.lead_L.gamma
367+
gammaR_full = self.lead_R.gamma
368+
lx = gammaL_full.shape[-2]
369+
rx = gammaR_full.shape[-2]
336370
x0 = min(lx, tx)
337371
x1 = min(rx, ty)
338372

339-
gammaL = torch.zeros(size=(tx, tx), dtype=self.cdtype, device=self.device)
340-
gammaL[:x0, :x0] += self.lead_L.gamma[:x0, :x0]
341-
gammaR = torch.zeros(size=(ty, ty), dtype=self.cdtype, device=self.device)
342-
gammaR[-x1:, -x1:] += self.lead_R.gamma[-x1:, -x1:]
373+
gL_shape = (g_trans.shape[0], tx, tx) if batched else (tx, tx)
374+
gR_shape = (g_trans.shape[0], ty, ty) if batched else (ty, ty)
375+
gammaL = torch.zeros(size=gL_shape, dtype=self.cdtype, device=self.device)
376+
gammaR = torch.zeros(size=gR_shape, dtype=self.cdtype, device=self.device)
377+
if batched:
378+
gammaL[:, :x0, :x0] = gammaL[:, :x0, :x0] + gammaL_full[:, :x0, :x0]
379+
gammaR[:, -x1:, -x1:] = gammaR[:, -x1:, -x1:] + gammaR_full[:, -x1:, -x1:]
380+
else:
381+
gammaL[:x0, :x0] += gammaL_full[:x0, :x0]
382+
gammaR[-x1:, -x1:] += gammaR_full[-x1:, -x1:]
343383

344-
tc = torch.mm(torch.mm(gammaL, self.g_trans), torch.mm(gammaR, self.g_trans.conj().T)).diag().real.sum(-1)
384+
tc = (gammaL @ g_trans @ gammaR @ g_trans.mH).diagonal(dim1=-2, dim2=-1).real.sum(-1)
345385

346386
return tc
347387

@@ -368,7 +408,7 @@ def _cal_dos_(self):
368408
temp = self.grd[jj] @ self.sd[jj] + self.grl[jj-1] @ self.su[jj-1]
369409
else:
370410
temp = self.grd[jj] @ self.sd[jj] + self.grl[jj-1] @ self.su[jj-1] + self.gru[jj] @ self.sl[jj]
371-
dos -= temp.imag.diag().sum(-1) / pi
411+
dos -= temp.imag.diagonal(dim1=-2, dim2=-1).sum(-1) / pi
372412
return dos * 2
373413

374414
def _cal_ldos_(self):
@@ -391,27 +431,28 @@ def _cal_ldos_(self):
391431
temp = self.grd[jj] @ self.sd[jj] + self.grl[jj-1] @ self.su[jj-1]
392432
else:
393433
temp = self.grd[jj] @ self.sd[jj] + self.grl[jj-1] @ self.su[jj-1] + self.gru[jj] @ self.sl[jj]
394-
ldos.append(-temp.imag.diag() / pi) # shape(Nd(diagonal elements))
434+
ldos.append(-temp.imag.diagonal(dim1=-2, dim2=-1) / pi) # [n_q] or [B, n_q]
395435

396-
ldos = torch.cat(ldos, dim=0).contiguous()
436+
ldos = torch.cat(ldos, dim=-1).contiguous()
397437

398438
norbs = [0]+self.norbs_per_atom
399439
accmap = np.cumsum(norbs)
400-
ldos = torch.stack([ldos[accmap[i]:accmap[i+1]].sum() for i in range(len(accmap)-1)])
440+
ldos = torch.stack([ldos[..., accmap[i]:accmap[i+1]].sum(-1) for i in range(len(accmap)-1)], dim=-1)
401441

402442
# return ldos*2
403443
return ldos*2
404444

405445
def _cal_local_current_(self):
406-
'''calculate the local current between different atoms
446+
'''calculate the local current between different atoms
407447
408448
At this stage, local current calculation only support non-block-triagonal format Hamiltonian
409-
449+
410450
Returns
411451
-------
412452
the local current
413-
453+
414454
'''
455+
# TODO(batched-energy): vectorize then batch — currently expects scalar-E gnd[0] (2-D).
415456
# current only support non-block-triagonal format
416457
v_L = self.lead_L.voltage
417458
v_R = self.lead_R.voltage

dpnegf/negf/lead_property.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def sigmaLR2Gamma(self, se):
450450
The Gamma function, Gamma = 1j(se-se^dagger).
451451
452452
'''
453-
return 1j * (se - se.conj().T)
453+
return 1j * (se - se.mH)
454454

455455
def fermi_dirac(self, x) -> torch.Tensor:
456456
return 1 / (1 + torch.exp((x - self.chemiPot_lead)/ self.kBT))

0 commit comments

Comments
 (0)