Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ppseq/batch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def __init__(self,
beta_a0: float=0.,
alpha_b0: float=0.,
beta_b0: float=0.,
alpha_t0: float=0.,
beta_t0:float=0.,
device=None):
super().__init__(num_templates,
num_neurons,
Expand All @@ -26,6 +28,8 @@ def __init__(self,
beta_a0,
alpha_b0,
beta_b0,
alpha_t0,
beta_t0,
device)

def fit(self,
Expand All @@ -40,7 +44,7 @@ def fit(self,


init_method = dict(random=self.initialize_random)[initialization.lower()]
amplitude_batches =[init_method(data) for data in data_batches]
amplitude_batches =[init_method(data.squeeze()) for data in data_batches]

# TODO: Initialize amplitudes more intelligently?
# amplitudes = torch.rand(K, T, device=self.device) + 1e-4
Expand All @@ -50,12 +54,13 @@ def fit(self,
for _ in progress_bar(range(num_iter)):
ll = 0
for i, data in enumerate(data_batches):
data = data.squeeze() # prevents indexing error when data_shape = (1, N, T) (e.g in a torch dataloader)
amplitude_batches[i] = self._update_amplitudes(data,
amplitude_batches[i])
self._update_base_rates(data, amplitude_batches[i])
self._update_templates(data, amplitude_batches[i])
ll += self.log_likelihood(data, amplitude_batches[i])
lps.append(ll)
lps.append(ll) #return the sum or avg log likelihood?

lps = torch.stack(lps) if num_iter > 0 else torch.tensor([])
return lps, amplitude_batches
Expand Down
13 changes: 13 additions & 0 deletions ppseq/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self,


self.base_rates = torch.ones(num_neurons, device=device)

self.template_scales = torch.ones(num_templates, num_neurons, device=device) / num_neurons
self.template_offsets = template_duration * torch.rand(num_templates, num_neurons, device=device)
self.template_widths = torch.ones(self.num_templates, self.num_neurons, device=device)
Expand All @@ -59,6 +60,17 @@ def __init__(self,
self.alpha_t0 = alpha_t0
self.beta_t0 = beta_t0

def to(self, map_location: str | torch.DeviceObjType | torch.dtype):

self.base_rates = self.base_rates.to(map_location)
self.template_scales = self.template_scales.to(map_location)
self.template_offsets = self.template_offsets.to(map_location)
self.template_widths = self.template_widths.to(map_location)

if not isinstance(map_location, torch.dtype):
self.device=map_location


@property
def templates(self) -> Float[Tensor, "num_templates num_neurons duration"]:
"""Compute the templates from the mean, std, and amplitude of the Gaussian kernel.
Expand All @@ -68,6 +80,7 @@ def templates(self) -> Float[Tensor, "num_templates num_neurons duration"]:
ds = torch.arange(D, device=self.device)[:, None, None]
p = dist.Normal(mu, sigma)
W = p.log_prob(ds).exp().permute(1,2,0)

return W / W.sum(dim=2, keepdim=True) * amp[:, :, None]

def reconstruct(self,
Expand Down