Skip to content
39 changes: 21 additions & 18 deletions pylabianca/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,9 @@ def shuffle_trials(spk, drop_timestamps=True, drop_waveforms=True):


# TODO: change name to something more descriptive like "select_data"?
# CONSIDER: ses_name -> ses_coord ?
# CONSIDER: change the loop to use .groupby() xarr method instead of _get_arr
# (might be faster)
def extract_data(xarr_dict, df, sub_col='sub', ses_col=None, ses_name='sub',
def extract_data(xarr_dict, df, sub_col='sub', ses_col=None, ses_coord='sub',
df2xarr=None):
'''Extract data from xarray dictionary using a dataframe.

Expand All @@ -258,6 +257,9 @@ def extract_data(xarr_dict, df, sub_col='sub', ses_col=None, ses_name='sub',
information.
ses_col : str | None
Name of the column in the DataFrame that contains session information.
ses_coord : str
Name of the subject / session coordinate in the xarray. Defaults to
``'sub'``.
df2xarr : dict | None
Dictionary that maps DataFrame columns to xarray coordinates. If None,
the default is ``{'label': 'region'}``.
Expand All @@ -281,7 +283,7 @@ def extract_data(xarr_dict, df, sub_col='sub', ses_col=None, ses_name='sub',
xarr_out = dict()
else:
import pandas as pd
keys = pd.unique(xarr_dict.coords[ses_name].values)
keys = pd.unique(xarr_dict.coords[ses_coord].values)
xarr_out = list()

# TODO - check for sub / ses consistency and raise / warn
Expand Down Expand Up @@ -309,7 +311,7 @@ def extract_data(xarr_dict, df, sub_col='sub', ses_col=None, ses_name='sub',
if ses is not None and ses_col is not None:
df_sel = df_sel.query(f'{ses_col} == "{ses}"')

xarr = _get_arr(xarr_dict, key, ses_name=ses_name)
xarr = _get_arr(xarr_dict, key, ses_coord=ses_coord)

n_cells = len(xarr.coords['cell'])
mask_all = np.zeros(n_cells, dtype=bool)
Expand Down Expand Up @@ -341,12 +343,12 @@ def extract_data(xarr_dict, df, sub_col='sub', ses_col=None, ses_name='sub',
return xarr_out, row_indices


def _get_arr(arr, sub_ses, ses_name='sub'):
def _get_arr(arr, sub_ses, ses_coord='sub'):
import xarray as xr
if isinstance(arr, dict):
arr = arr[sub_ses]
elif isinstance(arr, xr.DataArray):
arr = arr.query({'cell': f'{ses_name} == "{sub_ses}"'})
arr = arr.query({'cell': f'{ses_coord} == "{sub_ses}"'})
return arr


Expand All @@ -356,6 +358,9 @@ def _get_arr(arr, sub_ses, ses_name='sub'):
# - [ ] better argument names:
# -> is per_cell_query (per_cell_select etc.) even needed if we
# have per_cell=True and pass to specific subfunction?
# (I don't know what I meant by "pass to specific subfunction"..)
# - [ ] select vs per_cell_query behavior -> first one is done after zscoring
# the second one before
# ? option to pass the baseline calculated from a different period
def aggregate(frate, groupby=None, select=None, per_cell_query=None,
zscore=False, baseline=False, per_cell=False):
Expand Down Expand Up @@ -479,9 +484,9 @@ def _aggregate_xarray(frate, groupby, zscore, select, baseline):
frate : xarray.DataArray
Aggregated firing rate data.
"""

if zscore:
bsln = None if isinstance(zscore, bool) else zscore
is_zscore_bool = isinstance(zscore, bool)
if not is_zscore_bool or zscore:
bsln = None if is_zscore_bool else zscore
frate = zscore_xarray(frate, baseline=bsln)

if select is not None:
Expand Down Expand Up @@ -622,8 +627,7 @@ def zscore_xarray(arr, groupby='cell', baseline=None):
return arr


# CONSIDER: ses_name -> ses_coord ?
def dict_to_xarray(data, dim_name='cell', select=None, ses_name='sub'):
def dict_to_xarray(data, dim_name='cell', select=None, ses_coord='sub'):
'''Convert dictionary to xarray.DataArray.

Parameters
Expand All @@ -642,7 +646,7 @@ def dict_to_xarray(data, dim_name='cell', select=None, ses_name='sub'):
concatenation some coordinates may become multi-dimensional and
querying would raise an error "Unlabeled multi-dimensional array cannot
be used for indexing").
ses_name : str
ses_coord : str
Name of the subject / session coordinate that will be automatically
added to the concatenated dimension from the dictionary keys. Defaults
to ``'sub'``.
Expand Down Expand Up @@ -675,12 +679,12 @@ def dict_to_xarray(data, dim_name='cell', select=None, ses_name='sub'):
arr = arr.query(select)

# if trial was in select dict, then we should reset trial indices
if 'trial' in select:
if 'trial' in select and 'trial' in list(arr.coords):
arr = arr.reset_index('trial', drop=True)

# add subject / session information to the concatenated dimension
arr = assign_session_coord(
arr, key, dim_name=dim_name, ses_name=ses_name)
arr, key, dim_name=dim_name, ses_coord=ses_coord)

for coord_name, coord in arr.coords.items():
if coord_name not in all_coord_dims:
Expand Down Expand Up @@ -715,8 +719,7 @@ def _get_missing_value(dtype):
return None


# CONSIDER: ses_name -> ses_coord ?
def xarray_to_dict(xarr, ses_name='sub', reduce_coords=True,
def xarray_to_dict(xarr, ses_coord='sub', reduce_coords=True,
ensure_correct_reduction=True):
'''Convert multi-session xarray to dictionary of session -> xarray pairs.

Expand All @@ -727,7 +730,7 @@ def xarray_to_dict(xarr, ses_name='sub', reduce_coords=True,
----------
xarr : xarray.DataArray
Multi-session DataArray.
ses_name : str
ses_coord : str
Name of the session coordinate. Defaults to ``'sub'``.
reduce_coords : bool
If True, reduce coordinates were turned to cell x trial coordinates
Expand All @@ -752,7 +755,7 @@ def xarray_to_dict(xarr, ses_name='sub', reduce_coords=True,
'''
xarr_dct = dict()

sessions, ses_idx = np.unique(xarr.coords[ses_name].values, return_index=True)
sessions, ses_idx = np.unique(xarr.coords[ses_coord].values, return_index=True)

sort_idx = np.argsort(ses_idx)
sessions = sessions[sort_idx]
Expand Down
2 changes: 1 addition & 1 deletion pylabianca/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def _waveform_to_ft(spk, spikeTrials):
n_units = spk.n_units()
spikeTrials['waveform'] = np.empty(n_units, dtype='object')
for cell_idx in range(n_units):
# add "leads" dimention
# add "leads" dimension
this_waveform = spk.waveform[cell_idx]
if this_waveform is not None:
this_waveform = this_waveform.T[None, :]
Expand Down
4 changes: 2 additions & 2 deletions pylabianca/selectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ def compute_percent_selective(selectivity, threshold=None, dist=None,


# TODO: create apply_dict function (with out_type='dict' or 'xarray' etc.)
# TODO: expose ses_name as a parameter
# TODO: expose ses_coord as a parameter
def compute_selectivity_multisession(frate, compare=None, select=None,
n_perm=1_000, n_jobs=1):
"""
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def compute_selectivity_multisession(frate, compare=None, select=None,
if select is not None:
fr = fr.query({'trial': select})

fr = assign_session_coord(fr, ses, dim_name='cell', ses_name='sub')
fr = assign_session_coord(fr, ses, dim_name='cell', ses_coord='sub')

results = compute_selectivity_continuous(
fr, compare=compare, n_perm=n_perm, n_jobs=n_jobs)
Expand Down
6 changes: 3 additions & 3 deletions pylabianca/test/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_spike_centered_windows():

# make sure that using pln.utils version raises warning,
# but gives the same result
with pytest.warns(DeprecationWarning):
with pytest.warns(FutureWarning):
spk_cent3 = pln.utils.spike_centered_windows(
spk, xarr, winlen=0.01)

Expand Down Expand Up @@ -187,8 +187,8 @@ def compare_dicts(x_dct1, x_dct2):
compare_dicts(x_dct1, x_dct2)

# make sure we can do the same via pln.utils,
# but with a deprecation warning
with pytest.warns(DeprecationWarning):
# but with a FutureWarning
with pytest.warns(FutureWarning):
xarr3 = pln.utils.dict_to_xarray(x_dct1)
x_dct3 = pln.utils.xarray_to_dict(xarr3)

Expand Down
48 changes: 47 additions & 1 deletion pylabianca/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pylabianca as pln
from pylabianca.utils import (_get_trial_boundaries, find_cells,
create_random_spikes, _inherit_metadata)
from pylabianca.testing import gen_random_xarr


def test_trial_boundaries():
Expand Down Expand Up @@ -250,7 +251,6 @@ def test_turn_spike_rate_to_xarray():

def test_find_nested_dims():
import xarray as xr
from pylabianca.testing import gen_random_xarr

n_cells, n_trials, n_times = 5, 24, 100
tri_coord = np.random.choice(list('abcd'), size=n_trials)
Expand All @@ -264,3 +264,49 @@ def test_find_nested_dims():
assert isinstance(sub_dims, list)
assert len(sub_dims) == 1
assert 'cond' in sub_dims


def test_assign_session_coord():
    """Test assign_session_coord function with various scenarios."""
    # Test 1: Basic functionality with cell dimension
    n_cells, n_trials, n_times = 5, 24, 100
    xarr = gen_random_xarr(n_cells, n_trials, n_times)
    session_name = 'session_A'

    result = pln.utils.xarr.assign_session_coord(xarr, session_name)
    assert 'session' in result.coords
    assert (result.coords['session'].values == session_name).all()
    assert len(result.coords['session']) == n_cells

    # Test 2: Custom ses_coord name
    custom_coord_name = 'my_session'
    result = pln.utils.xarr.assign_session_coord(
        xarr, session_name, ses_coord=custom_coord_name
    )
    assert custom_coord_name in result.coords
    assert (result.coords[custom_coord_name].values == session_name).all()

    # Test 3: FutureWarning when using ses_name parameter
    with pytest.warns(FutureWarning, match='`ses_name` is deprecated'):
        result = pln.utils.xarr.assign_session_coord(
            xarr, session_name, ses_name='deprecated_session'
        )
    # the deprecated argument still decides the name of the coordinate
    assert 'deprecated_session' in result.coords

    # Test 4: Function works when cell is a coordinate but not in dims
    # (simulating arr.isel(cell=1))
    xarr_selected = xarr.isel(cell=1)
    assert 'cell' in xarr_selected.coords
    assert 'cell' not in xarr_selected.dims

    result = pln.utils.xarr.assign_session_coord(xarr_selected, session_name)
    assert 'session' in result.coords
    # When cell is not in dims, n_cells should be 1
    assert len(result.coords['session']) == 1
    assert result.coords['session'].values[0] == session_name

    # Test 5: ValueError when dim_name not found
    with pytest.raises(ValueError, match='Could not find dim_name'):
        pln.utils.xarr.assign_session_coord(
            xarr, session_name, dim_name='nonexistent'
        )
30 changes: 20 additions & 10 deletions pylabianca/utils/_compat.py
Original file line number Diff line number Diff line change
def dict_to_xarray(*args, **kwargs):
    '''Deprecated alias of ``pylabianca.analysis.dict_to_xarray``.

    Issues a FutureWarning and then passes all positional and keyword
    arguments on to ``pylabianca.analysis.dict_to_xarray``, returning its
    result.
    '''
    moved_msg = (
        '`pylabianca.utils.dict_to_xarray` has moved to '
        '`pylabianca.analysis.dict_to_xarray`. The old import path '
        'will be removed in a future version.'
    )
    # NOTE(review): stacklevel=3 attributes the warning to the caller of the
    # caller of this function - confirm this matches how the alias is exposed
    warnings.warn(moved_msg, FutureWarning, stacklevel=3)

    # imported at call time, after the warning has been issued
    from ..analysis import dict_to_xarray as _moved_function
    return _moved_function(*args, **kwargs)
def xarray_to_dict(*args, **kwargs):
    '''This function has moved. Use ``pylabianca.analysis.xarray_to_dict``
    instead.

    Issues a FutureWarning and then passes all positional and keyword
    arguments on to ``pylabianca.analysis.xarray_to_dict``, returning its
    result.
    '''

    # raise deprecation warning (the old location is ``pylabianca.utils``)
    # NOTE(review): stacklevel=3 attributes the warning to the caller of the
    # caller of this function - confirm this matches how the alias is exposed
    warnings.warn('`pylabianca.utils.xarray_to_dict` has moved to '
                  '`pylabianca.analysis.xarray_to_dict`. The old '
                  'import path will be removed in a future version.',
                  FutureWarning, stacklevel=3)

    from ..analysis import xarray_to_dict as _xarray_to_dict
    return _xarray_to_dict(*args, **kwargs)
def spike_centered_windows(*args, **kwargs):
    '''This function has moved. Use
    ``pylabianca.analysis.spike_centered_windows`` instead.

    Issues a FutureWarning and then passes all positional and keyword
    arguments on to ``pylabianca.analysis.spike_centered_windows``,
    returning its result.
    '''

    # raise deprecation warning (the old location is ``pylabianca.utils``)
    # NOTE(review): stacklevel=3 attributes the warning to the caller of the
    # caller of this function - confirm this matches how the alias is exposed
    warnings.warn('`pylabianca.utils.spike_centered_windows` has moved to '
                  '`pylabianca.analysis.spike_centered_windows`. The old '
                  'import path will be removed in a future version.',
                  FutureWarning, stacklevel=3)

    from ..analysis import spike_centered_windows as _spike_centered_windows
    return _spike_centered_windows(*args, **kwargs)
def shuffle_trials(*args, **kwargs):
    '''This function has moved. Use ``pylabianca.analysis.shuffle_trials``
    instead.

    Issues a FutureWarning and then passes all positional and keyword
    arguments on to ``pylabianca.analysis.shuffle_trials``, returning its
    result.
    '''

    # raise deprecation warning (the old location is ``pylabianca.utils``)
    # NOTE(review): stacklevel=3 attributes the warning to the caller of the
    # caller of this function - confirm this matches how the alias is exposed
    warnings.warn('`pylabianca.utils.shuffle_trials` has moved to '
                  '`pylabianca.analysis.shuffle_trials`. The old '
                  'import path will be removed in a future version.',
                  FutureWarning, stacklevel=3)

    from ..analysis import shuffle_trials as _shuffle_trials
    return _shuffle_trials(*args, **kwargs)
def read_drop_info(*args, **kwargs):
    '''This function has moved. Use ``pylabianca.io.read_drop_info`` instead.

    Issues a FutureWarning and then passes all positional and keyword
    arguments on to ``pylabianca.io.read_drop_info``, returning its result.
    '''

    # raise deprecation warning; the function now lives in ``pylabianca.io``
    # (not ``pylabianca.analysis``), the old location is ``pylabianca.utils``
    # NOTE(review): stacklevel=3 attributes the warning to the caller of the
    # caller of this function - confirm this matches how the alias is exposed
    warnings.warn('`pylabianca.utils.read_drop_info` has moved to '
                  '`pylabianca.io.read_drop_info`. The old '
                  'import path will be removed in a future version.',
                  FutureWarning, stacklevel=3)

    from ..io import read_drop_info as _read_drop_info
    return _read_drop_info(*args, **kwargs)
43 changes: 39 additions & 4 deletions pylabianca/utils/xarr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
import numpy as np
from .base import _deal_with_picks

Expand Down Expand Up @@ -204,9 +205,43 @@ def find_nested_dims(arr, dim_name):
return names


# CONSIDER: ses_name -> ses_coord ?
def assign_session_coord(arr, ses, dim_name='cell', ses_coord='session',
                         ses_name=None):
    '''Assign a coordinate with session info to all cells.

    Parameters
    ----------
    arr : xarray.DataArray
        Input xarray.
    ses : str
        Session name to assign.
    dim_name : str
        Name of the dimension corresponding to cells. If ``dim_name`` is
        present only as a coordinate and not as a dimension (for example
        after ``arr.isel(cell=1)``), it is first added as a length-one
        dimension at axis 0.
    ses_coord : str
        Name of the session coordinate to create. Defaults to
        ``'session'``.
    ses_name : str | None
        Deprecated, use ``ses_coord`` instead. If not None, it overrides
        ``ses_coord`` and a FutureWarning is issued.

    Returns
    -------
    arr : xarray.DataArray
        Xarray with session coordinate assigned.

    Raises
    ------
    ValueError
        If ``dim_name`` is neither a dimension nor a coordinate of ``arr``.
    '''
    # deprecate ses_name in favor of ses_coord
    if ses_name is not None:
        ses_coord = ses_name
        warnings.warn('`ses_name` is deprecated and will be removed in a '
                      'future release. Use `ses_coord` instead.',
                      FutureWarning, stacklevel=2)

    # check dim_name
    if dim_name in arr.dims:
        # .sizes also works when the dimension has no coordinate attached
        n_cells = arr.sizes[dim_name]
    elif dim_name in arr.coords:
        # coordinate without a dimension: turn it into a length-one dimension
        n_cells = 1
        arr = arr.expand_dims(dim_name, axis=0)
    else:
        raise ValueError(f'Could not find dim_name "{dim_name}" in arr.dims '
                         'or arr.coords.')

    # the same session label is repeated for every element along dim_name
    sub_dim = [ses] * n_cells
    arr = arr.assign_coords({ses_coord: (dim_name, sub_dim)})
    return arr