Skip to content

Commit e16aa44

Browse files
elainethale and claude committed
feat: gams.transfer write backend (Phase B)
TransferBackend.write_file builds a gams.transfer Container from the loaded gdxpds symbols (the inverse of the Phase A read translation) and writes it:

- value columns derive from value_col_names (GamsValueType), the same source the gdxcc backend uses -- no second hard-coded list to keep in sync
- inverse special-value mapping (eps -> EPS, NaN -> NA; +/-inf unchanged)
- empty element_text for Sets (strict parity with gdxcc: no set-text-write in v2.1.0; the membership-boolean wart collapses sets to all-False)
- strict/relaxed domain choice mirrors the gdxcc write path
- writing aliases is unsupported (to_gdx never infers one); raises clearly

Tests: write parity now covers the full write x read backend matrix (both engines write, both read each output) including transfer-write/transfer-read; plus the R12 mixed-boolean set-write gate and an alias-write NotImplementedError test for the one write branch parity cannot reach.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 33ce9d2 commit e16aa44

2 files changed

Lines changed: 183 additions & 22 deletions

File tree

src/gdxpds/_transfer_backend.py

Lines changed: 114 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""gams.transfer implementation of :class:`gdxpds._backend.GdxBackend` (read).
22
3-
Phase A: the read fast path. ``open_read`` builds the symbol metadata from a
4-
``gams.transfer`` Container (records-free), and ``load_symbols`` reads records
5-
(bulk or targeted) and translates each symbol into the gdxpds DataFrame shape so
6-
the result matches the gdxcc backend. ``write_file`` is not implemented yet
7-
(Phase B) and inherits the ABC default that raises.
3+
Read (Phase A): ``open_read`` builds the symbol metadata from a ``gams.transfer``
4+
Container (records-free), and ``load_symbols`` reads records (bulk or targeted)
5+
and translates each symbol into the gdxpds DataFrame shape so the result matches
6+
the gdxcc backend. Write (Phase B): ``write_file`` builds a Container from the
7+
gdxpds symbols (the inverse translation) and writes it. Writing aliases is not
8+
supported (to_gdx never infers one); use ``backend='gdxcc'`` for that.
89
910
``gams.transfer`` is imported at module load, but this module is itself imported
1011
lazily by :func:`gdxpds._backend.make_backend`, so ``import gdxpds`` stays free
@@ -57,6 +58,27 @@
5758
# either, so leaving it unmapped keeps the two backends consistent.
5859
}
5960

61+
# Write-path lookups: the read-side type tables turned around, so a gdxpds enum
# member resolves to its gams.transfer ``.type`` string.
_VAR_TYPE_STR = dict(zip(_VAR_TYPE.values(), _VAR_TYPE.keys()))
_EQU_TYPE_STR = dict(zip(_EQU_TYPE.values(), _EQU_TYPE.keys()))
64+
65+
66+
def _np_to_transfer_specials(records: pd.DataFrame, value_cols: list[str]) -> None:
    """Rewrite gdxpds special values as their gams.transfer encodings, in place.

    This is the write-path inverse of :func:`_convert_transfer_specials`. In
    every column named by ``value_cols``:

    - NaN becomes ``gt.SpecialValues.NA``;
    - machine eps (the last entry of ``NUMPY_SPECIAL_VALUES``) becomes
      ``gt.SpecialValues.EPS``;
    - +/-inf need no translation, and a genuine 0.0 is not treated as eps.

    Each column is converted to float64 and reassigned on ``records``; nothing
    is returned.
    """
    machine_eps = NUMPY_SPECIAL_VALUES[-1]
    for name in value_cols:
        values = records[name].to_numpy(dtype="float64", copy=True)
        # Compute both masks from the untouched values before overwriting any.
        nan_mask = np.isnan(values)
        eps_mask = np.abs(values - machine_eps) < machine_eps
        values[nan_mask] = gt.SpecialValues.NA
        values[eps_mask] = gt.SpecialValues.EPS
        records[name] = values
81+
6082

6183
def _data_type_of(gt_sym) -> GamsDataType:
6284
# UniverseAlias / Alias before Set (an alias is not a Set, but check the
@@ -89,12 +111,9 @@ def _convert_transfer_specials(values: pd.DataFrame) -> pd.DataFrame:
89111
eps = NUMPY_SPECIAL_VALUES[-1]
90112
out = values.copy()
91113
for col in out.columns:
92-
arr = np.asarray(out[col].to_numpy(dtype="float64"))
114+
arr = out[col].to_numpy(dtype="float64", copy=True)
93115
is_eps = np.asarray(gt.SpecialValues.isEps(arr))
94-
is_nan = np.asarray(gt.SpecialValues.isNA(arr)) | np.asarray(
95-
gt.SpecialValues.isUndef(arr)
96-
)
97-
arr = arr.copy()
116+
is_nan = np.asarray(gt.SpecialValues.isNA(arr)) | np.asarray(gt.SpecialValues.isUndef(arr))
98117
arr[is_nan] = np.nan
99118
arr[is_eps] = eps
100119
out[col] = arr
@@ -121,10 +140,90 @@ def close(self) -> None:
121140
self._container = None
122141

123142
def write_file(self, gdx_file: GdxFile, filename: str | os.PathLike[str]) -> None:
    """Write every symbol of ``gdx_file`` to ``filename`` through gams.transfer.

    A fresh ``gt.Container`` is populated by calling ``_add_symbol`` for each
    symbol in iteration order, then written out. On success
    ``gdx_file._filename`` is set to ``filename``.

    Raises:
        Error: if any symbol is not loaded, or if the gams.transfer write
            fails (the underlying exception is chained as ``__cause__``).
        NotImplementedError: propagated from ``_add_symbol`` for aliases.
    """
    # Validate up front so nothing is built or written for a partially
    # loaded file.
    for symbol in gdx_file:
        if not symbol.loaded:
            raise Error("All symbols must be loaded before this file can be written.")

    container = gt.Container(system_directory=self.gams_dir)
    # {name: position} for the per-symbol strict-domain eligibility check,
    # mirroring the gdxcc write path.
    name_positions = {name: i for i, name in enumerate(gdx_file._symbols.keys())}
    for symbol in gdx_file:
        self._add_symbol(container, symbol, name_positions)
    try:
        container.write(str(filename))
    except Exception as e:
        # Chain explicitly so the original gams.transfer failure is reported
        # as the direct cause rather than as an incidental context.
        raise Error(f"gams.transfer failed to write {filename!r}: {e}") from e
    gdx_file._filename = filename
158+
159+
def _gt_domain(self, container, symbol: GdxSymbol, name_positions: dict):
    """Return the ``domain`` argument to use when adding ``symbol`` to gt.

    Mirrors the gdxcc strict/relaxed choice:

    - a symbol with no dimensions gets an empty list;
    - when ``symbol._strict_domain_writeable(name_positions)`` holds, each
      entry of ``symbol._domain`` is resolved to the object stored under its
      name in ``container.data`` (a ``None`` entry becomes ``"*"``);
    - otherwise the plain dimension names from ``symbol.dims`` are used.
    """
    if symbol.num_dims == 0:
        return []
    if not symbol._strict_domain_writeable(name_positions):
        return list(symbol.dims)

    strict = []
    for parent in symbol._domain:
        strict.append("*" if parent is None else container.data[parent.name])
    return strict
170+
171+
def _add_symbol(self, container, symbol: GdxSymbol, name_positions: dict) -> None:
    """Create the gams.transfer counterpart of ``symbol`` inside ``container``.

    Sets are written from their dimension columns with an empty
    ``element_text``. Parameters, Variables and Equations are written from a
    copy of the full dataframe after mapping special values with
    ``_np_to_transfer_specials``.

    Raises:
        NotImplementedError: if ``symbol`` is an Alias.
    """
    kind = symbol.data_type
    if kind == GamsDataType.Alias:
        # to_gdx never infers an Alias, and writing one needs alias_with
        # plumbing; out of scope for v2.1.0 (use backend='gdxcc').
        raise NotImplementedError(
            "Writing aliases via the gams_transfer backend is not supported."
        )

    n_dims = symbol.num_dims
    gt_domain = self._gt_domain(container, symbol, name_positions)
    text = symbol.description or ""
    # gams.transfer matches domain columns by position, so they get unique
    # throwaway labels (avoiding duplicate '*' headers); value columns are
    # matched by name.
    index_cols = [f"_d{k}" for k in range(n_dims)]

    if kind == GamsDataType.Set:
        members = symbol.dataframe.iloc[:, :n_dims].copy()
        members.columns = index_cols
        # v2.1.0: no set-text-write (parity with gdxcc).
        members["element_text"] = ""
        gt.Set(container, symbol.name, domain=gt_domain, description=text, records=members)
        return

    # Parameter / Variable / Equation. The gams.transfer value-column names
    # are the gdxpds value_col_names lowercased (Value -> value,
    # Level -> level, ...), so no second hard-coded list is needed here.
    value_cols = [col.lower() for col in symbol.value_col_names]
    frame = symbol.dataframe.copy()
    frame.columns = index_cols + value_cols
    _np_to_transfer_specials(frame, value_cols)
    shared = {"domain": gt_domain, "description": text, "records": frame}

    if kind == GamsDataType.Parameter:
        gt.Parameter(container, symbol.name, **shared)
        return

    if kind == GamsDataType.Variable:
        var_type = symbol.variable_type
        # Missing or unmapped variable types fall back to "free".
        type_str = "free" if var_type is None else _VAR_TYPE_STR.get(var_type, "free")
        gt.Variable(container, symbol.name, type_str, **shared)
        return

    # Every remaining data type is written as an Equation; missing or
    # unmapped equation types fall back to "eq".
    equ_type = symbol.equation_type
    type_str = "eq" if equ_type is None else _EQU_TYPE_STR.get(equ_type, "eq")
    gt.Equation(container, symbol.name, type_str, **shared)
128227

129228
def open_read(self, gdx_file: GdxFile, filename: str | os.PathLike[str]) -> None:
130229
# Metadata only: keeps list_symbols / get_data_types cheap. Records are
@@ -176,9 +275,7 @@ def load_symbols(
176275
else:
177276
# Targeted: read just the requested symbols' records.
178277
targets = [s for s in symbols if not s.loaded]
179-
container = (
180-
self._read_records(gdx_file, [s.name for s in targets]) if targets else None
181-
)
278+
container = self._read_records(gdx_file, [s.name for s in targets]) if targets else None
182279
for symbol in targets:
183280
self._translate(container, symbol, load_set_text=load_set_text)
184281

tests/test_backend_parity.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,15 @@
1414

1515
import gdxpds
1616
from gdxpds import to_dataframes, to_gdx
17+
from gdxpds.gdx import GamsDataType, GdxFile
1718

18-
pytestmark = pytest.mark.skipif(
19-
not gdxpds.HAVE_GAMS_TRANSFER, reason="gams.transfer not available"
20-
)
19+
pytestmark = pytest.mark.skipif(not gdxpds.HAVE_GAMS_TRANSFER, reason="gams.transfer not available")
2120

2221
# Computed at import (collection) time because @parametrize needs the values
2322
# before the conftest ``data_dir`` fixture is available. Test bodies use the
2423
# ``data_dir`` fixture (repo convention); this constant just feeds parametrize.
2524
FIXTURES = sorted(
26-
os.path.basename(p)
27-
for p in glob.glob(os.path.join(os.path.dirname(__file__), "data", "*.gdx"))
25+
os.path.basename(p) for p in glob.glob(os.path.join(os.path.dirname(__file__), "data", "*.gdx"))
2826
)
2927

3028

@@ -80,3 +78,69 @@ def test_read_parity_symbol_subset(data_dir):
8078
b = to_dataframes(path, backend="gams_transfer", symbols=names)
8179
assert list(a) == names and list(b) == names
8280
_assert_same(a, b)
81+
82+
83+
# --- Phase B: write parity, over the full write x read backend matrix ---
84+
85+
# Engines the write-parity matrix uses, each as both the writer and the reader.
_BACKENDS = ("gdxcc", "gams_transfer")
86+
87+
88+
def _write_read_matrix(dfs, tmp_path, **kw):
    """Round-trip ``dfs`` through every (write backend, read backend) pairing.

    ``dfs`` is written once per backend in ``_BACKENDS`` to
    ``tmp_path / "via_<backend>.gdx"`` (``**kw`` is forwarded to ``to_gdx``),
    and each written file is then read back with every backend. The returned
    dict maps ``(write_backend, read_backend)`` to the dataframes read, so the
    full 2x2 matrix is exercised -- the cross-engine pairs as well as
    gams_transfer-write -> gams_transfer-read, which a gdxcc-only read-back
    would never touch.
    """
    written = {}
    for writer in _BACKENDS:
        target = str(tmp_path / f"via_{writer}.gdx")
        written[writer] = target
        to_gdx(dfs, target, backend=writer, **kw)

    matrix = {}
    for writer in _BACKENDS:
        for reader in _BACKENDS:
            matrix[(writer, reader)] = to_dataframes(written[writer], backend=reader)
    return matrix
99+
100+
101+
def _assert_matrix_consistent(matrix):
    """Assert that every entry of ``matrix`` matches the gdxcc/gdxcc entry.

    gdxcc-write + gdxcc-read is the oracle (the legacy round-trip); every
    other (write, read) combination must reproduce it.

    Raises:
        AssertionError: naming the failing ``(write, read)`` key, chained to
            the original failure raised by ``_assert_same``.
    """
    oracle = matrix[("gdxcc", "gdxcc")]
    for combo, frames in matrix.items():
        try:
            _assert_same(oracle, frames)
        except AssertionError as err:
            # Without this the failure does not say which pairing diverged.
            raise AssertionError(
                f"(write, read) = {combo!r} diverges from the gdxcc round-trip: {err}"
            ) from err
107+
108+
109+
@pytest.mark.parametrize("fixture", FIXTURES)
def test_write_parity(data_dir, fixture, tmp_path):
    """Each fixture GDX, read via gdxcc, agrees across the write x read matrix."""
    source_path = os.path.join(data_dir, fixture)
    frames = to_dataframes(source_path, backend="gdxcc")
    matrix = _write_read_matrix(frames, tmp_path)
    _assert_matrix_consistent(matrix)
113+
114+
115+
def test_write_parity_special_values(tmp_path):
    """NaN, +/-inf, machine eps and a plain 0.0 agree across the whole matrix."""
    machine_eps = np.finfo(float).eps
    param = pd.DataFrame(
        {"i": ["a", "b", "c", "d", "e"], "Value": [np.nan, np.inf, -np.inf, machine_eps, 0.0]}
    )
    scalar = pd.DataFrame({"Value": [42.0]})
    matrix = _write_read_matrix({"p": param, "scalar": scalar}, tmp_path)
    _assert_matrix_consistent(matrix)
124+
125+
126+
def test_write_parity_mixed_boolean_set(tmp_path):
    # R12 gate: gdxcc collapses every set element to 0.0 / c_bool(False) on
    # write (the membership-boolean wart). A Set with mixed True/False must
    # therefore read back all-False whichever engine wrote *or* read it.
    mixed = pd.DataFrame({"i": ["a", "b", "c"], "Value": [True, False, True]})
    matrix = _write_read_matrix({"s": mixed}, tmp_path)
    _assert_matrix_consistent(matrix)
    for frames in matrix.values():
        assert [bool(flag) for flag in frames["s"]["Value"]] == [False, False, False]
135+
136+
137+
def test_write_alias_unsupported(data_dir, tmp_path):
    # to_gdx never infers an Alias, so the parity tests cannot reach the Alias
    # branch of the write path. A GdxFile read from an alias-bearing GDX does
    # carry one; writing it via gams_transfer is explicitly unsupported in
    # v2.1.0 (use backend='gdxcc'). This pins the NotImplementedError contract.
    alias_gdx = os.path.join(data_dir, "alias_fixture.gdx")
    out_path = tmp_path / "out.gdx"

    gdx = GdxFile(lazy_load=False, backend="gams_transfer")
    gdx.read(alias_gdx)
    assert any(sym.data_type == GamsDataType.Alias for sym in gdx)

    with pytest.raises(NotImplementedError):
        gdx.write(str(out_path))

0 commit comments

Comments
 (0)