diff --git a/src/neat/trees/compartmenttree.py b/src/neat/trees/compartmenttree.py index eedf16f4..0737f4cd 100755 --- a/src/neat/trees/compartmenttree.py +++ b/src/neat/trees/compartmenttree.py @@ -1801,9 +1801,119 @@ def compute_c(self, alphas, phimat, weights=None, tau_eps=5.0): self._to_tree_c(c_vec) + # def fit_capacitances_with_matrix_exp( + # self, times: np.ndarray, target_kernels: np.ndarray, + # lr=1e-2, steps=1000, device="cpu", + # return_loss_history=False, early_stopping=False, patience=50, tol=1e-6 + # ): + # """ + # Fit compartment capacitances to match input resistance kernels + # from GreensTreeTime using matrix exponentials. + + # Parameters + # ---------- + # times : np.ndarray, shape (T,) + # Time points at which kernels are evaluated. + # target_kernels : np.ndarray, shape (T, n_comp) + # Target diagonal kernels from GreensTreeTime + # (response at i from injection at i). + # lr : float + # Learning rate for optimizer. + # steps : int + # Maximum number of optimization steps. + # device : str + # PyTorch device ("cpu" or "cuda"). + # return_loss_history : bool, default False + # If True, also return the list of loss values over optimization. + # early_stopping : bool, default False + # If True, stop optimization early if the loss does not improve. + # patience : int, default 50 + # Number of steps to wait for improvement before stopping. + # tol : float, default 1e-6 + # Minimum improvement in loss to reset patience. + + # Returns + # ------- + # kernels_pred : np.ndarray, shape (T, n_comp) + # Predicted fitted diagonal kernels. + # loss_history : list of float, optional + # Loss values at each optimization step (only if return_loss_history=True). + # """ + # import torch + # n_comp = len(self) + + # # Convert inputs to torch tensors + # times = torch.tensor(times, dtype=torch.float32, device=device) + # target_kernels = torch.tensor(target_kernels, dtype=torch.float32, device=device) + + # # Initial capacitance vector from tree + # c_init = self._to_c_vec() # numpy array (n_comp,) + # log_C = torch.nn.Parameter( + # torch.log(torch.tensor(c_init, dtype=torch.float32, device=device)) + # ) + # optimizer = torch.optim.Adam([log_C], lr=lr) + + # # Use NEAT's conductance matrix function + # G_np = self.calc_conductance_matrix() # numpy array (n_comp, n_comp) + # G = torch.tensor(G_np, dtype=torch.float32, device=device) + + # loss_history = [] + # best_loss = float('inf') + # steps_since_improvement = 0 + + # for step in range(steps): + # optimizer.zero_grad() + + # # Capacitances + # C = torch.exp(log_C) + # Cinv = torch.diag(1.0 / C) + + # # System matrix + # A = -Cinv @ G + + # # Vectorized computation of exp(A * t) for all times + # At = A.unsqueeze(0) * times[:, None, None] # (T, n_comp, n_comp) + # expAt = torch.matrix_exp(At) # (T, n_comp, n_comp) + + # # Kernel prediction: exp(At) @ C^{-1}, take diagonals + # kernels_pred = (expAt @ Cinv).diagonal(dim1=1, dim2=2) # (T, n_comp) + + # # Loss: mean squared error vs diagonal target kernels + # loss = torch.mean((kernels_pred - target_kernels) ** 2) + # loss.backward() + # optimizer.step() + + # # Track loss + # loss_value = loss.item() + # loss_history.append(loss_value) + + # if step % max(1, steps // 10) == 0: + # print(f"Step {step}, Loss {loss_value:.6e}") + + # # Early stopping logic + # if early_stopping: + # if loss_value + tol < best_loss: + # best_loss = loss_value + # steps_since_improvement = 0 + # else: + # steps_since_improvement += 1 + # if steps_since_improvement >= patience: + # print(f"Early stopping at step {step}, loss did not improve for {patience} steps.") + # break + + # # Write optimized capacitances back into tree + # C_opt = C.detach().cpu().numpy() + # self._to_tree_c(C_opt) + + # # Return fitted kernels and optionally loss history + # kernels_pred_np = kernels_pred.detach().cpu().numpy() + # if return_loss_history: + # return kernels_pred_np, loss_history + # else: + def _fit_res_action( - self, action, mat_feature, vec_target, weight, ca_lim=[], **kwargs - ): + self, action, mat_feature, vec_target, weight, ca_lim=[], **kwargs + ): if action == "fit": res = np.linalg.lstsq(mat_feature, vec_target, rcond=None) vec_res = res[0].real