Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 0 additions & 101 deletions pylabianca/scripts/test_region_changes.py

This file was deleted.

66 changes: 31 additions & 35 deletions pylabianca/selectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pandas as pd

from .analysis import nested_groupby_apply
from .utils import (find_nested_dims, cellinfo_from_xarray,
_inherit_metadata_from_xarray, assign_session_coord)
from .utils import (cellinfo_from_xarray, _inherit_metadata_from_xarray,
assign_session_coord)


# TODO: ! adapt for multiple cells
Expand Down Expand Up @@ -153,76 +153,72 @@ def compute_selectivity_continuous(frate, compare='image', n_perm=500,
from .stats import permutation_test

has_time = 'time' in frate.dims

frate_dims = ['trial', 'cell']
if has_time:
frate_dims.append('time')
frate_dims = ['trial', 'cell', 'time'] if has_time else ['trial', 'cell']
frate = frate.transpose(*frate_dims)

# permutations
# ------------
arrs = [arr for _, arr in frate.groupby(compare)]
stat_name = 't value' if len(arrs) == 2 else 'F value'
stat_unit = stat_name[0]

results = permutation_test(
*arrs, paired=False, n_perm=n_perm,
return_pvalue=False, return_distribution=True, n_jobs=n_jobs)

# turn to xarray
# --------------
cells = frate.cell.values

# perm
dims = ['perm'] + frate_dims[1:]
is_dict = isinstance(results, dict)
has_dist = is_dict and 'dist' in results
has_thresh = is_dict and 'thresh' in results
coords = {dim: frate.coords[dim].values.copy() for dim in frate_dims[1:]}

if n_perm > 0:
results['dist'] = xr.DataArray(data=results['dist'], dims=dims,
coords=coords, name=stat_name)
use_data = results['stat']
else:
use_data = results
results = dict()

# stat
use_data = results['stat'] if is_dict else results
results = results if is_dict else dict()

results['stat'] = xr.DataArray(
data=use_data, dims=dims[1:], coords=coords, name=stat_name)
data=use_data, dims=frate_dims[1:], coords=coords, name=stat_name)

# perm
if has_dist:
dims = ['perm'] + frate_dims[1:]
results['dist'] = xr.DataArray(data=results['dist'], dims=dims,
coords=coords, name=stat_name)

# thresh
if n_perm > 0:
if has_thresh:
if isinstance(results['thresh'], list) and len(results['thresh']) == 2:
# two-tail thresholds
dims2 = ['tail'] + dims[1:]
dims2 = ['tail'] + frate_dims[1:]
results['thresh'] = np.stack(results['thresh'], axis=0)
coords.update({'tail': ['pos', 'neg']})
else:
dims2 = dims[1:]
dims2 = frate_dims[1:]
coords.update({'tail': np.array('pos')})

results['thresh'] = xr.DataArray(
data=results['thresh'], dims=dims2, coords=coords, name=stat_name)

# copy unit information
# TODO: use a separate utility function
for key in results.keys():
results[key].attrs['unit'] = stat_unit
if 'coord_units' in frate.attrs:
results[key].attrs['coord_units'] = frate.attrs['coord_units']

# add cell coords
# TODO: move after Dataset creation
copy_coords = find_nested_dims(frate, 'cell')
if len(copy_coords) > 0:
for key in results.keys():
results[key] = _inherit_metadata_from_xarray(
frate, results[key], 'cell', copy_coords=copy_coords)

# transform to dataset:
_copy_dim_units(frate, results[key])

# transform to dataset and add cell coords:
results = xr.Dataset(results)
results = _inherit_metadata_from_xarray(frate, results, 'cell')

return results


def _copy_dim_units(xarr_from, xarr_to):
for dim in xarr_from.dims:
field = f'{dim}_unit'
if dim in xarr_to.dims and field in xarr_from.attrs:
xarr_to.attrs[field] = xarr_from.attrs[field]


# pbar is now True, which defaults to text tqdm, but could be 'auto'
# TODO: create more progress bars and pass to cluster_based_test
def cluster_based_selectivity(frate, compare, cluster_entry_pval=0.05,
Expand Down
10 changes: 7 additions & 3 deletions pylabianca/utils/xarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def _turn_spike_rate_to_xarray(times, frate, spike_epochs, cell_names=None,
coords = _inherit_metadata(
coords, spike_epochs.metadata, dimname, tri=tri)

# TODO: it seems that this copy_cellinfo part
# could also use _inherit_metadata
if copy_cellinfo:
if cell_names is not None and spike_epochs.cellinfo is not None:
ch_idx = _deal_with_picks(spike_epochs, cell_names)
Expand Down Expand Up @@ -156,18 +158,20 @@ def cellinfo_from_xarray(xarr):
return cellinfo


def _inherit_metadata(coords, metadata, dimname, tri=None):
def _inherit_metadata(coords, metadata, dim_name, tri=None):
'''Inherit metadata from a DataFrame to xarray coordinates.'''
if metadata is not None:
for col in metadata.columns:
if tri is None:
coords[col] = (dimname, metadata[col])
coords[col] = (dim_name, metadata[col])
else:
coords[col] = (dimname, metadata[col].iloc[tri])
coords[col] = (dim_name, metadata[col].iloc[tri])
return coords


def _inherit_metadata_from_xarray(xarr_from, xarr_to, dimname,
copy_coords=None):
'''Inherit metadata from one xarray to another.'''
if copy_coords is None:
copy_coords = find_nested_dims(xarr_from, dimname)
if len(copy_coords) > 0:
Expand Down