Skip to content

Commit 130bf4f

Browse files
committed
Add test for tb.dat
1 parent 9686791 commit 130bf4f

File tree

7 files changed

+31075
-41
lines changed

7 files changed

+31075
-41
lines changed

examples/silicon/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
build
2+
build_tb

examples/silicon/read_tb.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
#!/usr/bin/env python
2-
3-
# (c) 2015-2018, ETH Zurich, Institut fuer Theoretische Physik
4-
# Author: Dominik Gresch <[email protected]>
5-
2+
# Construct a Model from wannier90 tb.dat file.
63
import os
74
import shutil
85
import subprocess
@@ -12,29 +9,30 @@
129

1310
if __name__ == "__main__":
1411
WANNIER90_COMMAND = os.path.expanduser("~/git/wannier90/wannier90.x")
15-
BUILD_DIR = "./build"
16-
17-
# shutil.rmtree(BUILD_DIR, ignore_errors=True)
18-
# shutil.copytree("./input", BUILD_DIR)
19-
# subprocess.call([WANNIER90_COMMAND, "silicon"], cwd=BUILD_DIR)
12+
BUILD_DIR = "./build_tb"
2013

21-
# model = tb.Model.from_wannier_folder(BUILD_DIR, prefix="silicon")
22-
# print(model)
23-
# print(model.hop[(2, 1, -3)])
14+
if not os.path.exists(BUILD_DIR):
15+
shutil.copytree("./input", BUILD_DIR)
16+
subprocess.call([WANNIER90_COMMAND, "silicon"], cwd=BUILD_DIR)
2417

25-
BUILD_DIR = "./build2"
26-
model = tb.Model.from_wannier_tb_file(
27-
tb_file=f'{BUILD_DIR}/silicon_tb.dat',
28-
wsvec_file=f'{BUILD_DIR}/silicon_wsvec.dat'
18+
model = tb.Model.from_wannier_tb_files(
19+
tb_file=f"{BUILD_DIR}/silicon_tb.dat",
20+
wsvec_file=f"{BUILD_DIR}/silicon_wsvec.dat",
2921
)
3022
print(model)
31-
# print(model.hop[(2, 1, -3)])
32-
# exit()
3323

24+
# Compute band structure along an arbitrary kpath
3425
theta = 37 / 180 * np.pi
3526
phi = 43 / 180 * np.pi
3627
rlist = np.linspace(0, 2, 20)
37-
klist = [[r*np.sin(theta)*np.cos(phi), r*np.sin(theta)*np.sin(phi), r*np.cos(theta)] for r in rlist]
28+
klist = [
29+
[
30+
r * np.sin(theta) * np.cos(phi),
31+
r * np.sin(theta) * np.sin(phi),
32+
r * np.cos(theta),
33+
]
34+
for r in rlist
35+
]
3836

3937
eigvals = model.eigenval(klist)
4038

tbmodels/_tb_model.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def _read_tb(iterator, ignore_orbital_order=False):
538538

539539
lattice = np.zeros((3, 3))
540540
for i in range(3):
541-
lattice[i, :] = np.fromstring(next(iterator), sep=' ')
541+
lattice[i, :] = np.fromstring(next(iterator), sep=" ")
542542

543543
num_wann = int(next(iterator))
544544

@@ -554,28 +554,32 @@ def _read_tb(iterator, ignore_orbital_order=False):
554554
# <0n|H|Rm>
555555
hop_list = []
556556
for ir in range(nrpts):
557-
next(iterator) # skip empty
557+
next(iterator) # skip empty
558558
r_vec = [int(_) for _ in next(iterator).strip().split()]
559559
for j in range(num_wann):
560560
for i in range(num_wann):
561561
line = next(iterator).strip().split()
562562
iw, jw = [int(_) for _ in line[:2]]
563563
if not ignore_orbital_order and (iw != i + 1 or jw != j + 1):
564-
raise ValueError(f"Inconsistent orbital numbers in line '{line}'")
564+
raise ValueError(
565+
f"Inconsistent orbital numbers in line '{line}'"
566+
)
565567
ham = (float(line[2]) + 1j * float(line[3])) / deg_pts[ir]
566568
hop_list.append([ham, i, j, r_vec])
567569

568570
# <0n|r|Rm>
569571
r_list = []
570572
for ir in range(nrpts):
571-
next(iterator) # skip empty
573+
next(iterator) # skip empty
572574
r_vec = [int(_) for _ in next(iterator).strip().split()]
573575
for j in range(num_wann):
574576
for i in range(num_wann):
575577
line = next(iterator).strip().split()
576578
iw, jw = [int(_) for _ in line[:2]]
577579
if not ignore_orbital_order and (iw != i + 1 or jw != j + 1):
578-
raise ValueError(f"Inconsistent orbital numbers in line '{line}'")
580+
raise ValueError(
581+
f"Inconsistent orbital numbers in line '{line}'"
582+
)
579583
r = np.array([float(_) for _ in line[2:]])
580584
r = r[::2] + 1j * r[1::2]
581585
r_list.append([r, i, j, r_vec])
@@ -769,7 +773,7 @@ def remap_hoppings(hop_entries):
769773
return cls.from_hop_list(size=num_wann, hop_list=hop_entries, **kwargs)
770774

771775
@classmethod # noqa: MC0001
772-
def from_wannier_tb_file( # pylint: disable=too-many-locals
776+
def from_wannier_tb_files( # pylint: disable=too-many-locals
773777
cls,
774778
*,
775779
tb_file: str,
@@ -803,19 +807,19 @@ def from_wannier_tb_file( # pylint: disable=too-many-locals
803807

804808
with open(tb_file) as f:
805809
lattice, num_wann, nrpts, deg_pts, hop_list, r_list = cls._read_tb(f)
806-
810+
807811
kwargs["uc"] = lattice
808812

809-
def get_centers(r_list: list) -> list:
810-
centers = [None for _ in range(num_wann)]
813+
def get_centers(r_list: ty.List[ty.Any]) -> ty.List[npt.NDArray[np.float_]]:
814+
centers = [np.zeros(3) for _ in range(num_wann)]
811815
for r, i, j, r_vec in r_list:
812816
if r_vec != [0, 0, 0]:
813817
continue
814818
if i != j:
815819
continue
816820
r = np.array(r)
817821
if not np.allclose(np.abs(r.imag), 0):
818-
raise ValueError(f'Center should be real: WF {i+1}, center = {r}')
822+
raise ValueError(f"Center should be real: WF {i+1}, center = {r}")
819823
centers[i] = r.real
820824
return centers
821825

@@ -829,21 +833,15 @@ def get_centers(r_list: list) -> list:
829833
hop_entries = hop_list
830834

831835
with open(wsvec_file) as f:
832-
wsvec_generator = cls._async_parse(
833-
cls._read_wsvec(f), chunksize=num_wann
834-
)
836+
wsvec_generator = cls._async_parse(cls._read_wsvec(f), chunksize=num_wann)
835837

836838
def remap_hoppings(hop_entries):
837839
for t, orbital_1, orbital_2, R in hop_entries:
838840
# Step _async_parse to where it accepts
839841
# a new key.
840842
# The _async_parse does not raise StopIteration
841-
next( # pylint: disable=stop-iteration-return
842-
wsvec_generator
843-
)
844-
T_list = wsvec_generator.send(
845-
(orbital_1, orbital_2, tuple(R))
846-
)
843+
next(wsvec_generator) # pylint: disable=stop-iteration-return
844+
T_list = wsvec_generator.send((orbital_1, orbital_2, tuple(R)))
847845
N = len(T_list)
848846
for T in T_list:
849847
# not using numpy here increases performance
@@ -856,9 +854,7 @@ def remap_hoppings(hop_entries):
856854

857855
hop_entries = remap_hoppings(hop_entries)
858856

859-
return cls.from_hop_list(
860-
size=num_wann, hop_list=hop_entries, **kwargs
861-
)
857+
return cls.from_hop_list(size=num_wann, hop_list=hop_entries, **kwargs)
862858

863859
@staticmethod
864860
def _async_parse(iterator, chunksize=1):
Binary file not shown.

0 commit comments

Comments
 (0)