Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions src/pymatgen/io/lobster/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,24 +405,35 @@ def __init__(
# and we don't need the header.
if self._icohpcollection is None:
with zopen(self._filename, mode="rt", encoding="utf-8") as file:
all_lines = file.read().split("\n")
lines = all_lines[1:-1] if "spin" not in all_lines[1] else all_lines[2:-1]
if len(lines) == 0:
raise RuntimeError("ICOHPLIST file contains no data.")

# Determine LOBSTER version
if len(lines[0].split()) == 8 and "spin" not in all_lines[1]:
version = "3.1.1"
elif (len(lines[0].split()) == 8 or len(lines[0].split()) == 9) and "spin" in all_lines[1]:
version = "5.1.0"
elif len(lines[0].split()) == 6:
version = "2.2.1"
warnings.warn(
"Please consider using a newer LOBSTER version. See www.cohp.de.",
stacklevel=2,
)
else:
raise ValueError("Unsupported LOBSTER version.")
all_lines = file.read().splitlines()

# strip *trailing* blank lines only
all_lines = [line for line in all_lines if line.strip()]
# --- detect header length robustly ---
header_len = 0
try:
int(all_lines[0].split()[0])
except ValueError:
header_len += 1
if header_len < len(all_lines) and "spin" in all_lines[header_len].lower():
header_len += 1
lines = all_lines[header_len:]
if not lines:
raise RuntimeError("ICOHPLIST file contains no data.")
# --- version by column count only ---
ncol = len(lines[0].split())
if ncol == 6:
version = "2.2.1"
warnings.warn(
"Please consider using a newer LOBSTER version. See www.cohp.de.",
stacklevel=2,
)
elif ncol == 8:
version = "3.1.1"
elif ncol == 9:
version = "5.1.0"
else:
raise ValueError(f"Unsupported LOBSTER version ({ncol} columns).")

# If the calculation is spin polarized, the line in the middle
# of the file will be another header line.
Expand Down Expand Up @@ -587,6 +598,10 @@ def icohplist(self) -> dict[Any, dict[str, Any]]:
"translation": value._translation,
"orbitals": value._orbitals,
}

# for LCFO only files drop the single orbital resolved entry when not in orbitalwise mode
if self.is_lcfo and not self.orbitalwise:
icohp_dict = {k: d for k, d in icohp_dict.items() if d.get("orbitals") is None}
return icohp_dict

@property
Expand Down Expand Up @@ -1720,7 +1735,12 @@ def has_good_quality_check_occupied_bands(
raise ValueError("number_occ_bands_spin_down has to be specified")

for spin in (Spin.up, Spin.down) if spin_polarized else (Spin.up,):
num_occ_bands = number_occ_bands_spin_up if spin is Spin.up else number_occ_bands_spin_down
if spin is Spin.up:
num_occ_bands = number_occ_bands_spin_up
else:
if number_occ_bands_spin_down is None:
raise ValueError("number_occ_bands_spin_down has to be specified")
num_occ_bands = number_occ_bands_spin_down

for overlap_matrix in self.band_overlaps_dict[spin]["matrices"]:
sub_array = np.asarray(overlap_matrix)[:num_occ_bands, :num_occ_bands]
Expand Down Expand Up @@ -2333,7 +2353,7 @@ def _parse_matrix(
file_data: list[str],
pattern: str,
e_fermi: float,
) -> tuple[list[float], dict, dict]:
) -> tuple[list[np.ndarray], dict[Any, Any], dict[Any, Any]]:
complex_matrices: dict = {}
matrix_diagonal_values = []
start_inxs_real = []
Expand Down
20 changes: 19 additions & 1 deletion tests/io/lobster/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import json
import os
import tempfile

import numpy as np
import pytest
Expand Down Expand Up @@ -1955,7 +1956,7 @@ def test_attributes(self):
assert self.icohp_lcfo.is_spin_polarized
assert len(self.icohp_lcfo.icohplist) == 28
assert not self.icohp_lcfo_non_orbitalwise.orbitalwise
assert len(self.icohp_lcfo_non_orbitalwise.icohplist) == 27
assert len(self.icohp_lcfo_non_orbitalwise.icohplist) == 28

def test_values(self):
icohplist_bise = {
Expand Down Expand Up @@ -2171,6 +2172,23 @@ def test_msonable(self):
else:
assert getattr(icohplist_from_dict, attr_name) == attr_value

def test_missing_trailing_newline(self):
content = (
"1 Co1 O1 1.00000 0 0 0 -0.50000 -1.00000\n"
"2 Co2 O2 1.10000 0 0 0 -0.60000 -1.10000"
)

with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
tmp.write(content)
tmp.flush()
fname = tmp.name
try:
ip = Icohplist(filename=fname)
assert len(ip.icohplist) == 2
assert ip.icohplist["1"]["icohp"][Spin.up] == approx(-0.5)
finally:
os.remove(fname)


class TestNciCobiList:
def setup_method(self):
Expand Down