diff --git a/specparam/data/data.py b/specparam/data/data.py index cd3546b4..7f0d94af 100644 --- a/specparam/data/data.py +++ b/specparam/data/data.py @@ -9,6 +9,7 @@ from specparam.data import SpectrumMetaData, ModelChecks from specparam.utils.spectral import trim_spectrum from specparam.utils.checks import check_input_options +from specparam.reports.strings import gen_data_str from specparam.modutils.errors import DataError, InconsistentDataError from specparam.modutils.docs import docs_get_section, replace_docstring_sections from specparam.plts.settings import PLT_COLORS @@ -77,6 +78,17 @@ def has_data(self): return bool(np.any(self.power_spectrum)) + @property + def n_freqs(self): + """Indicator for the number of frequency values.""" + + n_freqs = None + if self.has_data: + n_freqs = len(self.freqs) + + return n_freqs + + def add_data(self, freqs, power_spectrum, freq_range=None): """Add data (frequencies, and power spectrum values) to the current object. @@ -151,6 +163,18 @@ def plot(self, plt_log=False, **plt_kwargs): log_powers=False, **data_kwargs) + def print(self, concise=False): + """Print out a data summary. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_data_str(self, concise)) + + def set_checks(self, check_freqs=None, check_data=None): """Set check statuses, which control if an error is raised based on check on the inputs. @@ -330,6 +354,17 @@ def has_data(self): return bool(np.any(self.power_spectra)) + @property + def n_spectra(self): + """Indicator for the number of power spectra.""" + + n_spectra = None + if self.has_data: + n_spectra = len(self.power_spectra) + + return n_spectra + + def add_data(self, freqs, power_spectra, freq_range=None): """Add data (frequencies and power spectrum values) to the current object. @@ -515,6 +550,17 @@ def n_events(self): return len(self.spectrograms) + @property + def n_spectra(self): + """Redefine n_spectra marker to reflect the total number of spectra.""" + + n_spectra = None + if self.has_data: + n_spectra = self.n_events * self.n_time_windows + + return n_spectra + + def add_data(self, freqs, spectrograms, freq_range=None): """Add data (frequencies and spectrograms) to the current object. diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index 778caea6..6e1f5239 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -118,6 +118,56 @@ def gen_version_str(concise=False): return output +def gen_data_str(data, concise=False): + """Generate a string representation summarizing current data. + + Parameters + ---------- + data : Data + Data object to summarize data for. + Can also be any derived data object (e.g. Data2D). + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of data summary. + """ + + if not data.has_data: + + no_data_str = "No data currently loaded in the object." + str_lst = [DIVIDER,'', no_data_str, '', DIVIDER] + + else: + + # Get number of spectra, checking attributes for {Data3D, Data2DT, Data2D, Data} + if getattr(data, 'n_events', None): + n_spectra_str = '{} spectrograms with {} windows each'.format(data.n_events, data.n_time_windows) + elif getattr(data, 'n_time_windows', None): + n_spectra_str = '1 spectrogram with {} windows'.format(data.n_time_windows) + elif getattr(data, 'n_spectra', None): + n_spectra_str = '{} power spectra'.format(data.n_spectra) + else: + n_spectra_str = '1 power spectrum' + + str_lst = [ + + DIVIDER, + '', + 'The data object contains {}'.format(n_spectra_str), + 'with a frequency range of {} Hz'.format(data.freq_range), + 'and a frequency resolution of {} Hz.'.format(data.freq_res), + '', + DIVIDER, + ] + + output = _format(str_lst, concise) + + return output + + def gen_modes_str(modes, description=False, concise=False): """Generate a string representation of fit modes. diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index 866dc9c8..d3b3a9d9 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -8,9 +8,9 @@ from specparam.modutils.dependencies import safe_import -from specparam.tests.tdata import (get_tdata, get_tdata2d, get_tfm, get_tfm2, get_tfg, get_tfg2, - get_tft, get_tfe, get_tbands, get_tresults, get_tmodes, - get_tdocstring) +from specparam.tests.tdata import (get_tdata, get_tdata2d, get_tdata2dt, get_tdata3d, + get_tfm, get_tfm2, get_tfg, get_tfg2, get_tft, get_tfe, + get_tbands, get_tresults, get_tmodes, get_tdocstring) from specparam.tests.tsettings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -67,6 +67,14 @@ def tdata(): def tdata2d(): yield get_tdata2d() +@pytest.fixture(scope='session') +def tdata2dt(): + yield get_tdata2dt() + +@pytest.fixture(scope='session') +def tdata3d(): + yield get_tdata3d() + @pytest.fixture(scope='session') def tfm(): yield get_tfm() diff --git a/specparam/tests/data/test_data.py b/specparam/tests/data/test_data.py index 503c3bd4..194e6ae6 100644 --- a/specparam/tests/data/test_data.py +++ b/specparam/tests/data/test_data.py @@ -16,6 +16,8 @@ def test_data(): tdata = Data() assert tdata + assert not tdata.has_data + assert not tdata.n_freqs def test_data_add_data(): @@ -23,6 +25,7 @@ def test_data_add_data(): freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) tdata.add_data(freqs, pows) assert tdata.has_data + assert tdata.n_freqs == len(freqs) def test_data_meta_data(): @@ -66,6 +69,7 @@ def test_data2d(): assert tdata2d assert isinstance(tdata2d, Data) assert isinstance(tdata2d, Data2D) + assert not tdata2d.has_data def test_data2d_add_data(): @@ -73,6 +77,7 @@ def test_data2d_add_data(): freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]) tdata2d.add_data(freqs, pows) assert tdata2d.has_data + assert tdata2d.n_spectra == len(pows) @plot_test def test_data2d_plot(tdata2d, skip_if_no_mpl): @@ -88,6 +93,7 @@ def test_data2dt(): assert isinstance(tdata2dt, Data) assert isinstance(tdata2dt, Data2D) assert isinstance(tdata2dt, Data2DT) + assert not tdata2dt.has_data def test_data2dt_add_data(): @@ -96,7 +102,7 @@ def test_data2dt_add_data(): tdata2dt.add_data(freqs, pows) assert tdata2dt.has_data assert np.all(tdata2dt.spectrogram) - assert tdata2dt.n_time_windows + assert tdata2dt.n_spectra == tdata2dt.n_time_windows == len(pows.T) ## 3D Data Object @@ -108,6 +114,7 @@ def test_data3d(): assert isinstance(tdata3d, Data2D) assert isinstance(tdata3d, Data2DT) assert isinstance(tdata3d, Data3D) + assert not tdata3d.has_data def test_data3d_add_data(): @@ -117,3 +124,4 @@ def test_data3d_add_data(): assert tdata3d.has_data assert np.all(tdata3d.spectrograms) assert tdata3d.n_events + assert tdata3d.n_spectra == 2 * len(pows.T) diff --git a/specparam/tests/reports/test_strings.py b/specparam/tests/reports/test_strings.py index 81765443..8833cf85 100644 --- a/specparam/tests/reports/test_strings.py +++ b/specparam/tests/reports/test_strings.py @@ -14,6 +14,13 @@ def test_gen_version_str(): assert gen_version_str() +def test_gen_data_str(tdata, tdata2d, tdata2dt, tdata3d): + + assert gen_data_str(tdata) + assert gen_data_str(tdata2d) + assert gen_data_str(tdata2dt) + assert gen_data_str(tdata3d) + def test_gen_modes_str(tfm): assert gen_modes_str(tfm.modes) diff --git a/specparam/tests/tdata.py b/specparam/tests/tdata.py index 936af8fe..85140d7f 100644 --- a/specparam/tests/tdata.py +++ b/specparam/tests/tdata.py @@ -4,7 +4,7 @@ from specparam.bands import Bands from specparam.modes.modes import Modes -from specparam.data.data import Data, Data2D +from specparam.data.data import Data, Data2D, Data2DT, Data3D from specparam.data.stores import FitResults from specparam.models import (SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel) @@ -50,6 +50,26 @@ def get_tdata2d(): return tdata2d +def get_tdata2dt(): + + n_spectra = 3 + tdata2dt = Data2DT() + tdata2dt.add_data(*sim_spectrogram(n_spectra, *default_group_params())) + + return tdata2dt + +def get_tdata3d(): + + n_events = 2 + n_spectra = 3 + tdata3d = Data3D() + freqs, spectrogram = sim_spectrogram(n_spectra, *default_group_params()) + tdata3d.add_data(freqs, [spectrogram] * n_events) + + return tdata3d + +## TEST MODEL OBJECTS + def get_tfm(): """Get a model object, with a fit power spectrum, for testing.""" @@ -117,6 +137,8 @@ def get_tfe(): return tfe +## TEST OTHER OBJECTS + def get_tbands(): """Get a bands object, for testing."""