Skip to content

Commit e0f40f7

Browse files
committed
Adding a test for class RandomStructureTransformation; change int to round in apply_transformation() of class RandomStructureTransformation.
1 parent 4e8fea7 commit e0f40f7

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

pymatgen/transformations/standard_transformations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def apply_transformation(self, structure: Structure, n_copies: int) -> list[Stru
517517

518518
el_list = list(subl_comp_dict.keys())
519519
el_concs = list(subl_comp_dict.values())
520-
lengths = [int(el_conc * len(subl_indices)) for el_conc in el_concs]
520+
lengths = [round(el_conc * len(subl_indices)) for el_conc in el_concs]
521521

522522
# randomly choose site indices for each element present in the sublattice
523523

tests/transformations/test_standard_transformations.py

+30
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
OxidationStateDecorationTransformation,
3030
OxidationStateRemovalTransformation,
3131
PartialRemoveSpecieTransformation,
32+
RandomStructureTransformation,
3233
PerturbStructureTransformation,
3334
PrimitiveCellTransformation,
3435
RemoveSpeciesTransformation,
@@ -320,6 +321,35 @@ def test_apply_transformations_best_first(self):
320321
assert len(trafo.apply_transformation(struct)) == 26
321322

322323

324+
class TestRandomStructureTransformation(unittest.TestCase):
325+
def test_apply_transformation(self):
326+
trafo = RandomStructureTransformation()
327+
coords = []
328+
coords.append([0, 0, 0])
329+
coords.append([0.25, 0.25, 0.25])
330+
lattice = Lattice(
331+
[
332+
[3.521253, 0.000000, 2.032996],
333+
[1.173751, 3.319869, 2.032996],
334+
[0.000000, 0.000000, 4.065993]
335+
]
336+
)
337+
338+
struct = Structure(
339+
lattice,
340+
[
341+
{"Ga3+": 0.5, "In3+": 0.5},
342+
{"As3-": 0.5, "P3-": 0.5}
343+
],
344+
coords)
345+
346+
struct.make_supercell([3, 3, 3])
347+
348+
output = trafo.apply_transformation(struct, n_copies=5)
349+
assert len(output) == 5
350+
assert isinstance(output[0], Structure)
351+
352+
323353
class TestOrderDisorderedStructureTransformation(unittest.TestCase):
324354
def test_apply_transformation(self):
325355
trafo = OrderDisorderedStructureTransformation()

0 commit comments

Comments
 (0)