Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tensorpac/pac.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def filter(self, sf, x, ftype='phase', keepfilt=False, edges=None,
assert isinstance(sf, (int, float)), ("The sampling frequency must be "
"a float number.")
# Compatibility between keepfilt and wavelet :
if (keepfilt is True) and (self._dcomplex is 'wavelet'):
if (keepfilt is True) and (self._dcomplex == 'wavelet'):
raise ValueError("Using wavelet for the complex decomposition do "
"not allow to get filtered data only. Set the "
"keepfilt parameter to False or set dcomplex to "
Expand All @@ -86,11 +86,11 @@ def filter(self, sf, x, ftype='phase', keepfilt=False, edges=None,

# ---------------------------------------------------------------------
# Switch between phase or amplitude :
if ftype is 'phase':
if ftype == 'phase':
tosend = 'pha' if not keepfilt else None
xfilt = spectral(x, sf, self.f_pha, tosend, self._dcomplex,
self._cycle[0], self._width, n_jobs)
elif ftype is 'amplitude':
elif ftype == 'amplitude':
tosend = 'amp' if not keepfilt else None
xfilt = spectral(x, sf, self.f_amp, tosend, self._dcomplex,
self._cycle[1], self._width, n_jobs)
Expand Down Expand Up @@ -155,7 +155,7 @@ def _infer_pvalues(self, effect, perm, p=.05, mcp='maxstat'):
# ---------------------------------------------------------------------
logger.info(f" infer p-values at (p={p}, mcp={mcp})")
# computes the pvalues
if mcp is 'maxstat':
if mcp == 'maxstat':
max_p = perm.reshape(n_perm, -1).max(1)[np.newaxis, ...]
nb_over = (effect[..., np.newaxis] <= max_p).sum(-1)
pvalues = nb_over / n_perm
Expand All @@ -164,7 +164,7 @@ def _infer_pvalues(self, effect, perm, p=.05, mcp='maxstat'):
pvalues = np.maximum(1. / n_perm, pvalues)
elif mcp in ['fdr', 'bonferroni']:
from mne.stats import fdr_correction, bonferroni_correction
fcn = fdr_correction if mcp is 'fdr' else bonferroni_correction
fcn = fdr_correction if mcp == 'fdr' else bonferroni_correction
# compute the p-values
pvalues = (effect[np.newaxis, ...] <= perm).sum(0) / n_perm
pvalues = np.maximum(1. / n_perm, pvalues)
Expand Down
8 changes: 4 additions & 4 deletions tensorpac/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def spectral(x, sf, f, stype, dcomplex, cycle, width, n_jobs):
"""
n_freqs = f.shape[0]
# Filtering + complex decomposition :
if dcomplex is 'hilbert':
if dcomplex == 'hilbert':
# get filtering coefficients
b = []
a = np.zeros((n_freqs,), dtype=float)
Expand All @@ -58,15 +58,15 @@ def spectral(x, sf, f, stype, dcomplex, cycle, width, n_jobs):
xd = np.asarray(xf)
if stype is not None:
xd = hilbertm(xd)
elif dcomplex is 'wavelet':
elif dcomplex == 'wavelet':
f = f.mean(1) # centered frequencies
xd = Parallel(n_jobs=n_jobs, **CONFIG['JOBLIB_CFG'])(delayed(morlet)(
x, sf, k, width) for k in f)

# Extract phase / amplitude :
if stype is 'pha':
if stype == 'pha':
return np.angle(xd).astype(np.float64)
elif stype is 'amp':
elif stype == 'amp':
return np.abs(xd).astype(np.float64)
elif stype is None:
return xd.astype(np.float64)
Expand Down
6 changes: 3 additions & 3 deletions tensorpac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _check_freq(f):
if len(f.reshape(-1)) == 1:
raise ValueError("The length of f should at least be 2.")
elif 2 in f.shape: # f of shape (N, 2) or (2, N)
if f.shape[1] is not 2:
if f.shape[1] != 2:
f = f.T
elif np.squeeze(f).shape == (4,): # (f_start, f_end, f_width, f_step)
f = _pair_vectors(*tuple(np.squeeze(f)))
Expand Down Expand Up @@ -202,7 +202,7 @@ def plot(self, f_min=None, f_max=None, confidence=95, interp=None,
plt.xlim(f_min, f_max)
if log:
from matplotlib.ticker import ScalarFormatter
plt.xscale('log', basex=10)
plt.xscale('log', base=10)
plt.gca().xaxis.set_major_formatter(ScalarFormatter())
if grid:
plt.grid(color='grey', which='major', linestyle='-',
Expand Down Expand Up @@ -262,7 +262,7 @@ def plot_st_psd(self, f_min=None, f_max=None, log=False, grid=True,
_viz.pacplot(psd, xvec, trials, **kw)
if log:
from matplotlib.ticker import ScalarFormatter
plt.xscale('log', basex=10)
plt.xscale('log', base=10)
plt.gca().xaxis.set_major_formatter(ScalarFormatter())
if grid:
plt.grid(color='grey', which='major', linestyle='-',
Expand Down