Skip to content

Commit cff5411

Browse files
committed
updated make_spiketrain_tools
1 parent 133a9c7 commit cff5411

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

expipe/analysis/general/plot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
def plot_raster(trials, color='b', lw=1, ax=None, marker='.', marker_size=10,
8-
ylabel='trials', id_start=0, ylim=None, dim='s'):
8+
ylabel='Trials', id_start=0, ylim=None, dim='s'):
99
"""
1010
Raster plot of trials
1111
@@ -48,6 +48,7 @@ def plot_raster(trials, color='b', lw=1, ax=None, marker='.', marker_size=10,
4848
t_start = trials[0].t_start.rescale(dim)
4949
t_stop = trials[0].t_stop.rescale(dim)
5050
ax.set_xlim([t_start, t_stop])
51+
ax.set_xlabel("Times ["+dim+"]")
5152
if ylabel is not None:
5253
ax.set_ylabel(ylabel)
5354
return ax

expipe/analysis/stimulus/tools.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,22 @@ def make_spiketrain_trials(epo, t_start, t_stop, unit=None, sptr=None):
124124
-------
125125
out : list of neo.SpikeTrains
126126
'''
127-
assert t_start != t_stop, 't_start cannot be equal to t_stop'
128-
dim = 's'
129-
t_start = t_start.rescale(dim)
130-
t_stop = t_stop.rescale(dim)
131127
from neo.core import SpikeTrain
128+
129+
if t_start.ndim == 0:
130+
t_starts = t_start * np.ones(len(epo.times))
131+
else:
132+
t_starts = t_start
133+
assert len(epo.times) == len(t_starts), 'epo.times and t_starts have different size'
134+
135+
if t_stop.ndim == 0:
136+
t_stops = t_stop * np.ones(len(epo.times))
137+
else:
138+
t_stops = epo.durations
139+
assert len(epo.times) == len(t_stops), 'epo.times and t_stops have different size'
140+
141+
dim = 's'
142+
132143
if sptr is None:
133144
assert unit is not None, 'unit and st cannot be both None'
134145
sptr = []
@@ -139,6 +150,8 @@ def make_spiketrain_trials(epo, t_start, t_stop, unit=None, sptr=None):
139150
sptr = sptr.rescale(dim)
140151
trials = []
141152
for j, t in enumerate(epo.times.rescale(dim)):
153+
t_start = t_starts[j].rescale(dim)
154+
t_stop = t_stops[j].rescale(dim)
142155
spikes = []
143156
for spike in sptr[(t+t_start < sptr) & (sptr < t+t_stop)]:
144157
spikes.append(spike-t)

0 commit comments

Comments
 (0)