Skip to content

Commit

Permalink
serializing (pickling in Python) of cif.Document and cif.Block (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
wojdyr committed Sep 21, 2024
1 parent 6f1d54f commit 79174a0
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 13 deletions.
1 change: 1 addition & 0 deletions include/gemmi/cifdoc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ struct Item {
Block frame;
};

// Default constructor: sets `type` to ItemType::Erased and does nothing else
// (the constructor body is empty, so none of pair/loop/frame is constructed).
Item() : type(ItemType::Erased) {}
explicit Item(LoopArg)
: type{ItemType::Loop}, loop{} {}
explicit Item(std::string&& t)
Expand Down
33 changes: 33 additions & 0 deletions include/gemmi/serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define GEMMI_SERIALIZE_HPP_

#include "model.hpp"
#include "cifdoc.hpp"

#define SERIALIZE(Struct, ...) \
template <typename Archive> \
Expand Down Expand Up @@ -144,6 +145,38 @@ SERIALIZE(Structure, o.name, o.cell, o.spacegroup_hm, o.models,
o.input_format, o.has_d_fraction, o.ter_status,
o.has_origx, o.origx, o.info, o.shortened_ccd_codes,
o.raw_remarks, o.resolution)


namespace cif {

// serialize() templates for the CIF container types, generated by the
// SERIALIZE macro. The arguments after the type name are presumably the
// members handed to the archive -- the macro body is not visible here; confirm.
SERIALIZE(Loop, o.tags, o.values)
SERIALIZE(Block, o.name, o.items)
SERIALIZE(Document, o.source, o.blocks)

// Non-const overload of serialize() for cif::Item.
// Archives `type` and `line_number` first, then default-constructs in place
// the member selected by `o.type` (pair, loop or frame) and archives it.
// NOTE(review): the placement new does not destroy a member that is already
// alive, so this presumably expects `o` to arrive in the Erased state
// (e.g. freshly default-constructed) -- confirm against how the archive
// creates items.
template <typename Archive>
void serialize(Archive& archive, Item& o) {
  archive(o.type, o.line_number);
  switch (o.type) {
    case ItemType::Pair:
    case ItemType::Comment:
      // Pair and Comment items both keep their data in `pair`.
      new (&o.pair) cif::Pair;
      archive(o.pair);
      break;
    case ItemType::Loop:
      new (&o.loop) cif::Loop;
      archive(o.loop);
      break;
    case ItemType::Frame:
      new (&o.frame) cif::Block;
      archive(o.frame);
      break;
    case ItemType::Erased:
      // Only type and line_number are archived for an erased item.
      break;
  }
}
// Const overload of serialize() for cif::Item.
// Archives `type` and `line_number`, followed by the member that corresponds
// to `o.type`; the item itself is only read, never constructed or modified.
template <typename Archive>
void serialize(Archive& archive, const Item& o) {
  archive(o.type, o.line_number);
  switch (o.type) {
    case ItemType::Pair:
    case ItemType::Comment:
      // Pair and Comment items both keep their data in `pair`.
      archive(o.pair);
      break;
    case ItemType::Loop:
      archive(o.loop);
      break;
    case ItemType::Frame:
      archive(o.frame);
      break;
    case ItemType::Erased:
      // Only type and line_number are archived for an erased item.
      break;
  }
}

} // namespace cif
} // namespace gemmi

#undef SERIALIZE
Expand Down
5 changes: 5 additions & 0 deletions python/cif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "gemmi/read_cif.hpp" // for read_cif_gz

#include "common.h"
#include "serial.h" // for getstate, setstate
#include "make_iterator.h"
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
Expand Down Expand Up @@ -168,6 +169,8 @@ void add_cif(nb::module_& cif) {
return os.str();
}, nb::arg("mmjson")=false, nb::arg("lowercase_names")=true,
"Returns JSON representation in a string.")
.def("__getstate__", &getstate<Document>)
.def("__setstate__", &setstate<Document>)
.def("__repr__", [](const Document &d) {
std::string s = "<gemmi.cif.Document with ";
s += std::to_string(d.blocks.size());
Expand Down Expand Up @@ -297,6 +300,8 @@ void add_cif(nb::module_& cif) {
write_cif_block_to_stream(os, self, opt);
return os.str();
}, nb::arg("options")=WriteOptions(), "Returns a string in CIF format.")
.def("__getstate__", &getstate<Block>)
.def("__setstate__", &setstate<Block>)
.def("__repr__", [](const Block &self) {
return gemmi::cat("<gemmi.cif.Block ", self.name, '>');
});
Expand Down
24 changes: 11 additions & 13 deletions tests/test_cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import gc
import os
import pickle
import unittest
from gemmi import cif

class TestDoc(unittest.TestCase):
def test_slice(self):
def test_slicing_and_pickling(self):
doc = cif.read_string("""
data_a
_one 1 _two 2 _three 3
Expand All @@ -15,23 +16,16 @@ def test_slice(self):
data_c
_two 2 _four 4 _six 6
""")
self.assertTrue('a' in doc)
self.assertFalse('d' in doc)
self.assertEqual([b.name for b in doc[:1]], ['a'])
self.assertEqual([b.name for b in doc[1:]], ['b', 'c'])
self.assertEqual([b.name for b in doc[:]], ['a', 'b', 'c'])
self.assertEqual([b.name for b in doc[1:-1]], ['b'])
self.assertEqual([b.name for b in doc[1:1]], [])

def test_contains(self):
doc = cif.read_string("""
data_a
_one 1 _two 2 _three 3
data_b
_four 4
data_c
_two 2 _four 4 _six 6
""")
self.assertEqual('a' in doc, True)
self.assertEqual('d' in doc, False)
unpickled = pickle.loads(pickle.dumps(doc))
self.assertEqual(unpickled.as_string(), doc.as_string())

class TestBlock(unittest.TestCase):
def test_find(self):
Expand Down Expand Up @@ -358,7 +352,11 @@ def test_case_sensitivity(self):
_One 1 _thrEE 3
_NonLoop_a alpha
loop_ _lbBb _ln B 1 D 2"""
self.assertEqual(block.as_string().split(), expected.split())
block_str = block.as_string()
self.assertEqual(block_str.split(), expected.split())

unpickled = pickle.loads(pickle.dumps(block))
self.assertEqual(unpickled.as_string(), block_str)

class TestQuote(unittest.TestCase):
def test_quote(self):
Expand Down

0 comments on commit 79174a0

Please sign in to comment.