Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 112 additions & 2 deletions src/neat/trees/compartmenttree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,9 +1801,119 @@ def compute_c(self, alphas, phimat, weights=None, tau_eps=5.0):

self._to_tree_c(c_vec)

# def fit_capacitances_with_matrix_exp(
# self, times: np.ndarray, target_kernels: np.ndarray,
# lr=1e-2, steps=1000, device="cpu",
# return_loss_history=False, early_stopping=False, patience=50, tol=1e-6
# ):
# """
# Fit compartment capacitances to match input resistance kernels
# from GreensTreeTime using matrix exponentials.

# Parameters
# ----------
# times : np.ndarray, shape (T,)
# Time points at which kernels are evaluated.
# target_kernels : np.ndarray, shape (T, n_comp)
# Target diagonal kernels from GreensTreeTime
# (response at i from injection at i).
# lr : float
# Learning rate for optimizer.
# steps : int
# Maximum number of optimization steps.
# device : str
# PyTorch device ("cpu" or "cuda").
# return_loss_history : bool, default False
# If True, also return the list of loss values over optimization.
# early_stopping : bool, default False
# If True, stop optimization early if the loss does not improve.
# patience : int, default 50
# Number of steps to wait for improvement before stopping.
# tol : float, default 1e-6
# Minimum improvement in loss to reset patience.

# Returns
# -------
# kernels_pred : np.ndarray, shape (T, n_comp)
# Predicted fitted diagonal kernels.
# loss_history : list of float, optional
# Loss values at each optimization step (only if return_loss_history=True).
# """
# import torch
# n_comp = len(self)

# # Convert inputs to torch tensors
# times = torch.tensor(times, dtype=torch.float32, device=device)
# target_kernels = torch.tensor(target_kernels, dtype=torch.float32, device=device)

# # Initial capacitance vector from tree
# c_init = self._to_c_vec() # numpy array (n_comp,)
# log_C = torch.nn.Parameter(
# torch.log(torch.tensor(c_init, dtype=torch.float32, device=device))
# )
# optimizer = torch.optim.Adam([log_C], lr=lr)

# # Use NEAT's conductance matrix function
# G_np = self.calc_conductance_matrix() # numpy array (n_comp, n_comp)
# G = torch.tensor(G_np, dtype=torch.float32, device=device)

# loss_history = []
# best_loss = float('inf')
# steps_since_improvement = 0

# for step in range(steps):
# optimizer.zero_grad()

# # Capacitances
# C = torch.exp(log_C)
# Cinv = torch.diag(1.0 / C)

# # System matrix
# A = -Cinv @ G

# # Vectorized computation of exp(A * t) for all times
# At = A.unsqueeze(0) * times[:, None, None] # (T, n_comp, n_comp)
# expAt = torch.matrix_exp(At) # (T, n_comp, n_comp)

# # Kernel prediction: exp(At) @ C^{-1}, take diagonals
# kernels_pred = (expAt @ Cinv).diagonal(dim1=1, dim2=2) # (T, n_comp)

# # Loss: mean squared error vs diagonal target kernels
# loss = torch.mean((kernels_pred - target_kernels) ** 2)
# loss.backward()
# optimizer.step()

# # Track loss
# loss_value = loss.item()
# loss_history.append(loss_value)

# if step % max(1, steps // 10) == 0:
# print(f"Step {step}, Loss {loss_value:.6e}")

# # Early stopping logic
# if early_stopping:
# if loss_value + tol < best_loss:
# best_loss = loss_value
# steps_since_improvement = 0
# else:
# steps_since_improvement += 1
# if steps_since_improvement >= patience:
# print(f"Early stopping at step {step}, loss did not improve for {patience} steps.")
# break

# # Write optimized capacitances back into tree
# C_opt = C.detach().cpu().numpy()
# self._to_tree_c(C_opt)

# # Return fitted kernels and optionally loss history
# kernels_pred_np = kernels_pred.detach().cpu().numpy()
# if return_loss_history:
# return kernels_pred_np, loss_history
# else:

def _fit_res_action(
self, action, mat_feature, vec_target, weight, ca_lim=[], **kwargs
):
self, action, mat_feature, vec_target, weight, ca_lim=[], **kwargs
):
if action == "fit":
res = np.linalg.lstsq(mat_feature, vec_target, rcond=None)
vec_res = res[0].real
Expand Down
Loading