1616"""
1717log = logging .getLogger (__name__ )
1818
19+
20+ def _build_s_in_batched (hd , seinL , seinR , idx0 , idy0 , idx1 , idy1 ):
21+ '''Allocate per-block [B, n_q, n_q] zeros and inject the corner self-energy slices.
22+
23+ Mirrors the scalar construction in cal_green_function (the seinL/seinR fed in
24+ here are the already-scaled 1j*(seL - seL.mH) * f tensors of shape [B,n,n]).
25+ '''
26+ B = seinL .shape [0 ]
27+ s_in = [torch .zeros ((B ,) + tuple (blk .shape ), dtype = torch .complex128 ) for blk in hd ]
28+ s_in [0 ][:, :idx0 , :idy0 ] = s_in [0 ][:, :idx0 , :idy0 ] + seinL [:, :idx0 , :idy0 ]
29+ s_in [- 1 ][:, - idx1 :, - idy1 :] = s_in [- 1 ][:, - idx1 :, - idy1 :] + seinR [:, - idx1 :, - idy1 :]
30+ return s_in
31+
32+
1933class DeviceProperty (object ):
2034 '''Device object for NEGF calculation
2135
@@ -150,8 +164,12 @@ def cal_green_function(self, energy, kpoint, eta_device=0., block_tridiagonal=Tr
150164 A boolean parameter that indicates whether the last column blocks of the retarded Green's function are needed.
151165 '''
152166 assert len (np .array (kpoint ).reshape (- 1 )) == 3
153- if not isinstance (energy , torch .Tensor ):
154- energy = torch .tensor (energy , dtype = torch .complex128 )
167+ energy = torch .as_tensor (energy , dtype = torch .complex128 )
168+ if energy .ndim == 0 :
169+ energy = energy .reshape (1 )
170+ assert energy .ndim == 1 , f"energy must be 0-d, scalar, or 1-D [B]; got shape { tuple (energy .shape )} "
171+ B = energy .shape [0 ]
172+ batched_mode = B > 1
155173
156174 self .block_tridiagonal = block_tridiagonal
157175 if self .kpoint is None or abs (self .kpoint - torch .tensor (kpoint )).sum () > 1e-5 :
@@ -200,33 +218,45 @@ def cal_green_function(self, energy, kpoint, eta_device=0., block_tridiagonal=Tr
200218
201219 seL = self .lead_L .se
202220 seR = self .lead_R .se
221+ if batched_mode :
222+ assert seL .ndim == 3 and seR .ndim == 3 , f"In batched mode, the self-energy should have shape [B,n,n], but got { seL .shape } and { seR .shape } "
223+ else :
224+ assert seL .ndim == 2 and seR .ndim == 2 , f"In non-batched mode, the self-energy should have shape [n,n], but got { seL .shape } and { seR .shape } "
225+
203226 s01 , s02 = self .hd [0 ].shape # The shape of the first H block
204- se01 , se02 = seL .shape # The shape of the left self-energy
227+ se01 , se02 = seL .shape [ - 2 ], seL . shape [ - 1 ] # last two dims work for [n,n] and [B,n,n]
205228 s11 , s12 = self .hd [- 1 ].shape
206- se11 , se12 = seR .shape
229+ se11 , se12 = seR .shape [ - 2 ], seR . shape [ - 1 ]
207230 idx0 , idy0 = min (s01 , se01 ), min (s02 , se02 )
208231 idx1 , idy1 = min (s11 , se11 ), min (s12 , se12 )
209232 if block_tridiagonal :
210233 # Based on the block tridiagonal algorithm, the shape of the self-energy should be
211- # equal to or larger than the corresponding Hamiltonian block
234+ # equal to or lesser than the corresponding Hamiltonian block
212235 if se01 > s01 or se02 > s02 :
213236 log .warning (f"The shape of left self-energy ({ se01 } ,{ se02 } ) is larger than\
214237 the first Hamiltonian block ({ s01 } ,{ s02 } )." )
215238 raise ValueError ("Left Lead Self Energy size is larger than the first Hamiltonian Block." )
216239 if se11 > s11 or se12 > s12 :
217- log .warning (f"The shape of right self-energy ({ se11 } ,{ se12 } ) is different from \
240+ log .warning (f"The shape of right self-energy ({ se11 } ,{ se12 } ) is larger than \
218241 the last Hamiltonian block ({ s11 } ,{ s12 } )." )
219242 raise ValueError ("Right Lead Self Energy size is larger than the last Hamiltonian Block." )
220243
221244 green_funcs = {}
222245
223246 if need_lesser :
224247 # Fluctuation-Dissipation theorem; only build s_in when the lesser GF is consumed
225- seinL = 1j * (seL - seL .conj ().T ) * self .lead_L .fermi_dirac (energy + self .E_ref ).reshape (- 1 )
226- seinR = 1j * (seR - seR .conj ().T ) * self .lead_R .fermi_dirac (energy + self .E_ref ).reshape (- 1 )
227- s_in = [torch .zeros (i .shape ).cdouble () for i in self .hd ]
228- s_in [0 ][:idx0 ,:idy0 ] = s_in [0 ][:idx0 ,:idy0 ] + seinL [:idx0 ,:idy0 ]
229- s_in [- 1 ][- idx1 :,- idy1 :] = s_in [- 1 ][- idx1 :,- idy1 :] + seinR [- idx1 :,- idy1 :]
248+ if batched_mode :
249+ fL = self .lead_L .fermi_dirac (energy + self .E_ref ).reshape (B , 1 , 1 )
250+ fR = self .lead_R .fermi_dirac (energy + self .E_ref ).reshape (B , 1 , 1 )
251+ seinL = 1j * (seL - seL .mH ) * fL
252+ seinR = 1j * (seR - seR .mH ) * fR
253+ s_in = _build_s_in_batched (self .hd , seinL , seinR , idx0 , idy0 , idx1 , idy1 )
254+ else :
255+ seinL = 1j * (seL - seL .conj ().T ) * self .lead_L .fermi_dirac (energy + self .E_ref ).reshape (- 1 )
256+ seinR = 1j * (seR - seR .conj ().T ) * self .lead_R .fermi_dirac (energy + self .E_ref ).reshape (- 1 )
257+ s_in = [torch .zeros (i .shape ).cdouble () for i in self .hd ]
258+ s_in [0 ][:idx0 ,:idy0 ] = s_in [0 ][:idx0 ,:idy0 ] + seinL [:idx0 ,:idy0 ]
259+ s_in [- 1 ][- idx1 :,- idy1 :] = s_in [- 1 ][- idx1 :,- idy1 :] + seinR [- idx1 :,- idy1 :]
230260 else :
231261 s_in = 0
232262
@@ -322,26 +352,36 @@ def _cal_current_nscf_(self, energy_grid, tc):
322352
323353
324354 def _cal_tc_ (self ):
325- '''calculate the transmission coefficient
326-
355+ '''calculate the transmission coefficient
356+
327357 Returns
328358 -------
329- tc is the transmission coefficient
330-
359+ tc is the transmission coefficient
360+
331361 '''
332362
333- tx , ty = self .g_trans .shape
334- lx , ly = self .lead_L .se .shape
335- rx , ry = self .lead_R .se .shape
363+ g_trans = self .g_trans
364+ batched = g_trans .ndim == 3
365+ tx , ty = g_trans .shape [- 2 ], g_trans .shape [- 1 ]
366+ gammaL_full = self .lead_L .gamma
367+ gammaR_full = self .lead_R .gamma
368+ lx = gammaL_full .shape [- 2 ]
369+ rx = gammaR_full .shape [- 2 ]
336370 x0 = min (lx , tx )
337371 x1 = min (rx , ty )
338372
339- gammaL = torch .zeros (size = (tx , tx ), dtype = self .cdtype , device = self .device )
340- gammaL [:x0 , :x0 ] += self .lead_L .gamma [:x0 , :x0 ]
341- gammaR = torch .zeros (size = (ty , ty ), dtype = self .cdtype , device = self .device )
342- gammaR [- x1 :, - x1 :] += self .lead_R .gamma [- x1 :, - x1 :]
373+ gL_shape = (g_trans .shape [0 ], tx , tx ) if batched else (tx , tx )
374+ gR_shape = (g_trans .shape [0 ], ty , ty ) if batched else (ty , ty )
375+ gammaL = torch .zeros (size = gL_shape , dtype = self .cdtype , device = self .device )
376+ gammaR = torch .zeros (size = gR_shape , dtype = self .cdtype , device = self .device )
377+ if batched :
378+ gammaL [:, :x0 , :x0 ] = gammaL [:, :x0 , :x0 ] + gammaL_full [:, :x0 , :x0 ]
379+ gammaR [:, - x1 :, - x1 :] = gammaR [:, - x1 :, - x1 :] + gammaR_full [:, - x1 :, - x1 :]
380+ else :
381+ gammaL [:x0 , :x0 ] += gammaL_full [:x0 , :x0 ]
382+ gammaR [- x1 :, - x1 :] += gammaR_full [- x1 :, - x1 :]
343383
344- tc = torch . mm ( torch . mm ( gammaL , self . g_trans ), torch . mm ( gammaR , self . g_trans .conj (). T )). diag ( ).real .sum (- 1 )
384+ tc = ( gammaL @ g_trans @ gammaR @ g_trans .mH ). diagonal ( dim1 = - 2 , dim2 = - 1 ).real .sum (- 1 )
345385
346386 return tc
347387
@@ -368,7 +408,7 @@ def _cal_dos_(self):
368408 temp = self .grd [jj ] @ self .sd [jj ] + self .grl [jj - 1 ] @ self .su [jj - 1 ]
369409 else :
370410 temp = self .grd [jj ] @ self .sd [jj ] + self .grl [jj - 1 ] @ self .su [jj - 1 ] + self .gru [jj ] @ self .sl [jj ]
371- dos -= temp .imag .diag ( ).sum (- 1 ) / pi
411+ dos -= temp .imag .diagonal ( dim1 = - 2 , dim2 = - 1 ).sum (- 1 ) / pi
372412 return dos * 2
373413
374414 def _cal_ldos_ (self ):
@@ -391,27 +431,28 @@ def _cal_ldos_(self):
391431 temp = self .grd [jj ] @ self .sd [jj ] + self .grl [jj - 1 ] @ self .su [jj - 1 ]
392432 else :
393433 temp = self .grd [jj ] @ self .sd [jj ] + self .grl [jj - 1 ] @ self .su [jj - 1 ] + self .gru [jj ] @ self .sl [jj ]
394- ldos .append (- temp .imag .diag ( ) / pi ) # shape(Nd(diagonal elements))
434+ ldos .append (- temp .imag .diagonal ( dim1 = - 2 , dim2 = - 1 ) / pi ) # [n_q] or [B, n_q]
395435
396- ldos = torch .cat (ldos , dim = 0 ).contiguous ()
436+ ldos = torch .cat (ldos , dim = - 1 ).contiguous ()
397437
398438 norbs = [0 ]+ self .norbs_per_atom
399439 accmap = np .cumsum (norbs )
400- ldos = torch .stack ([ldos [accmap [i ]:accmap [i + 1 ]].sum () for i in range (len (accmap )- 1 )])
440+ ldos = torch .stack ([ldos [..., accmap [i ]:accmap [i + 1 ]].sum (- 1 ) for i in range (len (accmap )- 1 )], dim = - 1 )
401441
402442 # return ldos*2
403443 return ldos * 2
404444
405445 def _cal_local_current_ (self ):
406- '''calculate the local current between different atoms
446+ '''calculate the local current between different atoms
407447
408448 At this stage, local current calculation only support non-block-triagonal format Hamiltonian
409-
449+
410450 Returns
411451 -------
412452 the local current
413-
453+
414454 '''
455+ # TODO(batched-energy): vectorize then batch — currently expects scalar-E gnd[0] (2-D).
415456 # current only support non-block-triagonal format
416457 v_L = self .lead_L .voltage
417458 v_R = self .lead_R .voltage
0 commit comments