diff --git a/doc/changes/devel/13165.bugfix.rst b/doc/changes/devel/13165.bugfix.rst new file mode 100644 index 00000000000..bbff6a554fd --- /dev/null +++ b/doc/changes/devel/13165.bugfix.rst @@ -0,0 +1 @@ +Fixes behavior of stc.save() filetypes and suffixes while saving through SourceEstimate.save method in mne.source_estimate by `Shresth Keshari`_. \ No newline at end of file diff --git a/mne/source_estimate.py b/mne/source_estimate.py index deeb3a43ede..420b1e36fa6 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -1884,7 +1884,7 @@ class SourceEstimate(_BaseSurfaceSourceEstimate): """ @verbose - def save(self, fname, ftype="stc", *, overwrite=False, verbose=None): + def save(self, fname, ftype="auto", *, overwrite=False, verbose=None): """Save the source estimates to a file. Parameters @@ -1894,18 +1894,29 @@ def save(self, fname, ftype="stc", *, overwrite=False, verbose=None): spaces are obtained by adding ``"-lh.stc"`` and ``"-rh.stc"`` (or ``"-lh.w"`` and ``"-rh.w"``) to the stem provided, for the left and the right hemisphere, respectively. - ftype : str - File format to use. Allowed values are ``"stc"`` (default), - ``"w"``, and ``"h5"``. The ``"w"`` format only supports a single - time point. + ftype : "auto" | "stc" | "w" | "h5" + File format to use. If "auto", the file format will be inferred from the + file extension if possible. Other allowed values are ``"stc"``, ``"w"``, and + ``"h5"``. The ``"w"`` format only supports a single time point. %(overwrite)s .. versionadded:: 1.0 %(verbose)s """ fname = str(_check_fname(fname=fname, overwrite=True)) # checked below + if ftype == "auto": + if fname.endswith(".stc"): + ftype = "stc" + elif fname.endswith(".w"): + ftype = "w" + elif fname.endswith(".h5"): + ftype = "h5" + else: + logger.info( + "Cannot infer file type from `fname`; falling back to `.stc` format" + ) + ftype = "stc" _check_option("ftype", ftype, ["stc", "w", "h5"]) - lh_data = self.data[: len(self.lh_vertno)] rh_data = self.data[-len(self.rh_vertno) :] @@ -1918,6 +1929,8 @@ def save(self, fname, ftype="stc", *, overwrite=False, verbose=None): "real numbers before saving." ) logger.info("Writing STC to disk...") + if fname.endswith(".stc"): + fname = fname[:-4] fname_l = str(_check_fname(fname + "-lh.stc", overwrite=overwrite)) fname_r = str(_check_fname(fname + "-rh.stc", overwrite=overwrite)) _write_stc( diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index e4fa5a36b25..e9d7a2fb43d 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -481,8 +481,7 @@ def test_io_stc(tmp_path): """Test IO for STC files.""" stc = _fake_stc() stc.save(tmp_path / "tmp.stc") - stc2 = read_source_estimate(tmp_path / "tmp.stc") - + stc2 = read_source_estimate(tmp_path / "tmp") assert_array_almost_equal(stc.data, stc2.data) assert_array_almost_equal(stc.tmin, stc2.tmin) assert_equal(len(stc.vertices), len(stc2.vertices))