diff --git a/.gitignore b/.gitignore
index d662796..d1fa0e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,10 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 GDeNet_old/
-old
\ No newline at end of file
+old
+DYMAG_solver_version/lyapunov_results
+DYMAG_solver_version/wandb
+DYMAG_solver_version/figs
+DYMAG_solver_version/lightning_logs
+DYMAG_solver_version/data
+*~
diff --git a/DYMAG_solver_version/DYMAG.py b/DYMAG_solver_version/DYMAG.py
new file mode 100644
index 0000000..5146cc0
--- /dev/null
+++ b/DYMAG_solver_version/DYMAG.py
@@ -0,0 +1,92 @@
+import os
+import sys
+import torch
+from aggregators import KHopSumAggregator, GraphMomentAggregator
+from PDE_layer import PDE_layer
+
+# A torch module that applies a PDE_layer, then a KHopSumAggregator, then a
+# GraphMomentAggregator, then flattens the output before passing it to a classifier.
+
+
+class DYMAG(torch.nn.Module):
+    def __init__(self,
+                 input_feature_dim,
+                 output_dim,
+                 dynamics='sprott',
+                 n_largest_graph=10,
+                 K=3,
+                 M=4,
+                 S=4,
+                 num_layers=2,
+                 num_lin_layers_after_pde=2,
+                 device='cpu',
+                 ):
+        super(DYMAG, self).__init__()
+
+        self.input_feature_dim = input_feature_dim
+        self.output_dim = output_dim
+        self.K = K
+        self.M = M
+        self.S = S
+        self.dynamics = dynamics
+
+        self.num_layers = num_layers
+        self.num_lin_layers_after_pde = num_lin_layers_after_pde
+        self.device = device
+
+        self.pde_layer = PDE_layer(dynamics=dynamics, n_largest_graph=n_largest_graph)
+        self.k_hop_sum_aggregator = KHopSumAggregator(self.K, self.M)
+        self.graph_moment_aggregator = GraphMomentAggregator(self.S)
+
+        self.time_points = self.pde_layer.output_times
+        self.aggregated_size = self.S * self.K * self.M * self.input_feature_dim * len(self.time_points)
+
+        self.lin_layers = torch.nn.ModuleList()
+
+        print('input size is', self.aggregated_size)
+        layer_size_list = [self.aggregated_size, 64, 48, 32]
+        for i in range(len(layer_size_list) - 1):
+            self.lin_layers.append(torch.nn.Linear(layer_size_list[i], layer_size_list[i+1]))
+        self.classifier = torch.nn.Linear(layer_size_list[-1], output_dim)
+
+        self.nonlin = torch.nn.LeakyReLU()
+        self.outnonlin = torch.nn.Sigmoid()
+
+    def forward(self, x, edge_index, batch_index):
+        x = self.pde_layer(x, edge_index, batch_index)
+        x = self.k_hop_sum_aggregator(x, edge_index)
+        x = self.graph_moment_aggregator(x, batch_index)
+
+        # keep the first axis but flatten all the rest
+        x = x.view(x.size(0), -1)
+
+        for lin_layer in self.lin_layers:
+            x = lin_layer(x)
+            x = self.nonlin(x)
+
+        x = self.classifier(x)
+        # try without sigmoid
+        return x
+        #return self.outnonlin(x)
+
+    def reset_parameters(self):
+        for layer in self.children():
+            if hasattr(layer, 'reset_parameters'):
+                layer.reset_parameters()
+
+
+if __name__ == '__main__':
+    # test the model
+    num_nodes = 10
+    num_features = 100
+    x = torch.randn(num_nodes, num_features)
+    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], dtype=torch.long)
+    batch_index = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.long)
+    model = DYMAG(num_features, 1)
+    print(model(x, edge_index, batch_index).shape)
+    print(model)
+    # get the number of trainable parameters for the model
+    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
+    import pdb; pdb.set_trace()
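The width of the flattened representation, `aggregated_size`, sizes the first linear layer, so it is worth a quick sanity check. With the defaults used in the `__main__` block above (an illustrative snippet, not part of the diff; the time-point count mirrors `PDE_layer.output_times` with its defaults `final_t=5`, `sampling_interval=0.2`):

```python
# Sanity check of DYMAG's flattened input width (values mirror the defaults
# above; illustrative only).
K, M, S = 3, 4, 4                    # hops, neighborhood moments, graph moments
final_t, sampling_interval = 5, 0.2
num_times = int(final_t / sampling_interval) + 1   # 26 sampled time points
input_feature_dim = 100
aggregated_size = S * K * M * input_feature_dim * num_times
print(aggregated_size)               # 124800 features into the first Linear layer
```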
diff --git a/DYMAG_solver_version/DYMAG_pl.py b/DYMAG_solver_version/DYMAG_pl.py
new file mode 100644
index 0000000..57da3ae
--- /dev/null
+++ b/DYMAG_solver_version/DYMAG_pl.py
@@ -0,0 +1,100 @@
+import os
+import sys
+import torch
+import pytorch_lightning as pl
+from aggregators import KHopSumAggregator, GraphMomentAggregator
+from PDE_layer import PDE_layer
+
+class DYMAG_pl(pl.LightningModule):
+    def __init__(self,
+                 input_feature_dim,
+                 output_dim,
+                 dynamics='sprott',
+                 K=3,
+                 M=4,
+                 S=4,
+                 num_layers=2,
+                 num_lin_layers_after_pde=2,
+                 learning_rate=1e-3,
+                 custom_device='cpu'):
+        super(DYMAG_pl, self).__init__()
+
+        self.save_hyperparameters()
+        self.input_feature_dim = input_feature_dim
+        self.output_dim = output_dim
+        self.K = K
+        self.M = M
+        self.S = S
+        self.num_layers = num_layers
+        self.num_lin_layers_after_pde = num_lin_layers_after_pde
+        self.custom_device = custom_device
+        self.learning_rate = learning_rate
+        self.dynamics = dynamics
+
+        self.pde_layer = PDE_layer(dynamics=dynamics)
+        self.k_hop_sum_aggregator = KHopSumAggregator(self.K, self.M)
+        self.graph_moment_aggregator = GraphMomentAggregator(self.S)
+
+        self.time_points = self.pde_layer.output_times
+        self.aggregated_size = self.S * self.K * self.M * self.input_feature_dim * len(self.time_points)
+
+        self.lin_layers = torch.nn.ModuleList()
+        print('input size is', self.aggregated_size)
+        layer_size_list = [self.aggregated_size, 64, 48, 32]
+        for i in range(len(layer_size_list) - 1):
+            self.lin_layers.append(torch.nn.Linear(layer_size_list[i], layer_size_list[i+1]))
+        self.classifier = torch.nn.Linear(layer_size_list[-1], output_dim)
+
+        self.nonlin = torch.nn.LeakyReLU()
+
+    def forward(self, x, edge_index, batch_index):
+        x = self.pde_layer(x, edge_index, batch_index)
+        x = self.k_hop_sum_aggregator(x, edge_index)
+        x = self.graph_moment_aggregator(x, batch_index)
+
+        # keep the first axis but flatten all the rest
+        x = x.view(x.size(0), -1)
+
+        for lin_layer in self.lin_layers:
+            x = lin_layer(x)
+            x = self.nonlin(x)
+
+        x = self.classifier(x)
+        return x
+
+    def reset_parameters(self):
+        for layer in self.children():
+            if hasattr(layer, 'reset_parameters'):
+                layer.reset_parameters()
+
+    def training_step(self, batch, batch_idx):
+        out = self.forward(batch.x, batch.edge_index, batch.batch)
+        loss = torch.nn.functional.mse_loss(out, batch.y)
+        self.log('train_loss', loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        out = self.forward(batch.x, batch.edge_index, batch.batch)
+        val_loss = torch.nn.functional.mse_loss(out, batch.y)
+        self.log('val_loss', val_loss)
+
+    def test_step(self, batch, batch_idx):
+        out = self.forward(batch.x, batch.edge_index, batch.batch)
+        test_loss = torch.nn.functional.mse_loss(out, batch.y)
+        self.log('test_loss', test_loss)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        return optimizer
+
+if __name__ == '__main__':
+    # Test the model
+    num_nodes = 10
+    num_features = 5
+    x = torch.randn(num_nodes, num_features)
+    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], dtype=torch.long)
+    batch_index = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.long)
+    model = DYMAG_pl(num_features, 5, dynamics='heat')
+    print(model(x, edge_index, batch_index).shape)
+    print(model)
diff --git a/DYMAG_solver_version/PDE_layer.py b/DYMAG_solver_version/PDE_layer.py
new file mode 100644
index 0000000..f2d13c1
--- /dev/null
+++ b/DYMAG_solver_version/PDE_layer.py
@@ -0,0 +1,186 @@
+import torch
+from torch_geometric.nn import MessagePassing
+from torch_geometric.utils import add_self_loops, degree
+from torch_geometric.data import Batch
+import torch.nn as nn
+
+class PDE_layer(MessagePassing):
+    """
+    PDE_layer solves a graph PDE with message passing and fixed-step numerical integration.
+
+    Args:
+        dynamics (str): Which time derivative to integrate. Can be 'heat' or 'sprott'.
+        n_largest_graph (int): Largest graph size expected; sizes the random coupling
+            weights used by the Sprott dynamics.
+        b (float): Damping coefficient of the Sprott dynamics.
+        **kwargs: Optional 'solver', 'sampling_interval', 'final_t', and 'step_size'.
+
+    Attributes:
+        step_size (float): The step size for numerical integration.
+        solver (str): The solver method for solving the PDE. Can be 'euler' or 'rk4'.
+        sampling_interval (float): The interval at which to sample the solution.
+        final_t (float): The final time for the integration.
+        derivative_func (callable): The time derivative selected by `dynamics`.
+
+    Methods:
+        get_laplacian: Computes the Laplacian of the input data.
+        forward: Performs the forward pass of the PDE solver.
+    """
+
+    def __init__(self, dynamics='sprott', n_largest_graph=100, b=0.25, **kwargs):
+        super(PDE_layer, self).__init__(aggr='add')
+        self.step_size = kwargs.get('step_size', .01)
+        self.solver = kwargs.get('solver', 'rk4')
+        # set up sampling_interval and final_t from kwargs if provided
+        self.sampling_interval = kwargs.get('sampling_interval', .2)
+        self.final_t = kwargs.get('final_t', 5)
+        self.b = b
+        #self.random_weights = torch.rand((n_largest_graph, n_largest_graph)) - 0.5
+        # set random weights to +1 or -1, scaled by 1/sqrt(n_largest_graph - 1)
+        self.random_weights = (torch.randint(0, 2, (n_largest_graph, n_largest_graph)) * 2 - 1) / (n_largest_graph - 1)**(1/2)
+        if dynamics == 'heat':
+            self.derivative_func = self.heat_dynamic
+        elif dynamics == 'sprott':
+            self.derivative_func = self.sprott_dynamic
+        else:
+            raise ValueError(f'Unknown dynamics: {dynamics}')
+        print(f'Initialized with {dynamics} dynamics')
+
+        self.output_times = torch.arange(0, self.final_t + self.sampling_interval, self.sampling_interval)
+
+    def heat_dynamic(self, x, edge_index, batch):
+        return -self.get_laplacian(x, edge_index, batch)
+
+    def sprott_dynamic(self, x, edge_index, batch):
+        """
+        Sprott dynamics:
+        du_i/dt = -b * u_i + tanh(sum_j a_ij * u_j)
+        """
+        # row, col represent source and target nodes of directed edges
+        row, col = edge_index
+
+        # Create a map from batched node indices to original indices
+        batch_node_count = (batch == 0).sum().item()  # assumes uniform size across graphs in the batch
+        batch_offset = batch * batch_node_count
+
+        # Adjust row and col indices by subtracting the batch offset
+        row_adjusted = row - batch_offset[row]
+        col_adjusted = col - batch_offset[col]
+
+        # Propagate messages using random weights
+        #weighted_x = self.random_weights[row_adjusted, col_adjusted][:, None] * x[col]
+
+        # Use self.propagate to aggregate the messages
+        aggregated_message = self.propagate(edge_index, x=x, norm=self.random_weights[row_adjusted, col_adjusted])
+
+        # Apply Sprott dynamics
+        dt = -self.b * x + torch.tanh(aggregated_message)
+        return dt
+
+    def get_laplacian(self, x, edge_index, batch, normalized=True):
+        """
+        Computes the Laplacian of the input data.
+
+        Args:
+            x (Tensor): The input data.
+            edge_index (LongTensor): The edge indices of the graph.
+            batch (LongTensor): The batch indices of the graph.
+            normalized (bool): Whether to normalize the Laplacian.
+
+        Returns:
+            Tensor: The Laplacian applied to the input data.
+        """
+        #edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
+        row, col = edge_index
+        if normalized:
+            deg = degree(row, num_nodes=x.size(0), dtype=torch.float)
+            deg_inv_sqrt = deg.pow(-0.5)
+            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
+
+            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
+        else:
+            norm = None
+
+        adj = self.propagate(edge_index, x=x, norm=norm, size=(x.size(0), x.size(0)))
+        return x - adj
+
+    def forward(self, x, edge_index, batch):
+        """
+        Performs the forward pass of the PDE solver.
+
+        Args:
+            x (Tensor): The input data.
+            edge_index (LongTensor): The edge indices of the graph.
+            batch (LongTensor): The batch indices of the graph.
+
+        Returns:
+            Tensor: The solution of the PDE at different time steps. Output has shape [time_steps, num_nodes, num_features]
+        """
+        num_nodes = x.size(0)
+        batch_size = batch.max().item() + 1
+
+        if self.solver == 'euler':
+            num_steps = int(self.final_t / self.step_size)
+            # round to guard against floating-point floor-division artifacts
+            sampling_interval_steps = int(round(self.sampling_interval / self.step_size))
+            num_outputs = (num_steps // sampling_interval_steps) + 1
+
+            outputs = torch.zeros((int(num_outputs), num_nodes, x.size(1)), device=x.device, requires_grad=False)
+            outputs[0] = x
+
+            output_idx = 1
+            for t_step in range(1, num_steps + 1):
+                dt = self.derivative_func(x, edge_index, batch)
+                x = x + self.step_size * dt
+                if t_step % sampling_interval_steps == 0:
+                    outputs[output_idx] = x
+                    output_idx += 1
+
+            return outputs
+
+        elif self.solver == 'rk4':
+            num_steps = int(self.final_t / self.step_size)
+            sampling_interval_steps = int(round(self.sampling_interval / self.step_size))
+            num_outputs = (num_steps // sampling_interval_steps) + 1
+
+            outputs = torch.zeros((int(num_outputs), num_nodes, x.size(1)), device=x.device, requires_grad=False)
+            outputs[0] = x
+
+            output_idx = 1
+            for t_step in range(1, num_steps + 1):
+                # Compute an RK4 step
+                k1 = self.step_size * self.derivative_func(x, edge_index, batch)
+                k2 = self.step_size * self.derivative_func(x + 0.5 * k1, edge_index, batch)
+                k3 = self.step_size * self.derivative_func(x + 0.5 * k2, edge_index, batch)
+                k4 = self.step_size * self.derivative_func(x + k3, edge_index, batch)
+
+                x = x + (1/6) * (k1 + 2 * k2 + 2 * k3 + k4)
+
+                if t_step % sampling_interval_steps == 0:
+                    outputs[output_idx] = x
+                    output_idx += 1
+
+            return outputs
+
+    def message(self, x_j, norm):
+        if norm is None:
+            return x_j
+        return norm.view(-1, 1) * x_j
+
+    def update(self, aggr_out):
+        return aggr_out
+
+if __name__ == '__main__':
+    # Create dummy input data
+    x = torch.randn(10, 3)  # Shape: [num_nodes, num_features]
+    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], dtype=torch.long)  # Shape: [2, num_edges]
+    # make edge_index undirected
+    edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
+
+    batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.long)  # Shape: [num_nodes]
+
+    # Create an instance of PDE_layer
+    pde_layer = PDE_layer()
+
+    # Perform the forward pass
+    solution = pde_layer.forward(x, edge_index, batch)
+
+    # Print the shape of the solution
+    print(solution.shape)
diff --git a/DYMAG_solver_version/__init__.py b/DYMAG_solver_version/__init__.py
new file mode 100644
index 0000000..4316473
--- /dev/null
+++ b/DYMAG_solver_version/__init__.py
@@ -0,0 +1,2 @@
+from .PDE_layer import PDE_layer
+from .aggregators import KHopSumAggregator, GraphMomentAggregator
diff --git a/DYMAG_solver_version/aggregators.py b/DYMAG_solver_version/aggregators.py
new file mode 100644
index 0000000..5aa176d
--- /dev/null
+++ b/DYMAG_solver_version/aggregators.py
@@ -0,0 +1,98 @@
+# the output of PDE_layer is a tensor of shape (num_outputs, num_nodes, num_features)
+# make an aggregator that sums the values in the 1, 2, ..., K hop neighborhoods of each node
+# the result should be a tensor of shape (num_outputs, num_nodes, K, M, num_features)
+
+import torch
+from torch_geometric.utils import k_hop_subgraph
+#from torch_geometric.nn import GCNConv
+
+class KHopSumAggregator(torch.nn.Module):
+    def __init__(self, K=3, M=4):  # M and K referenced from appendix B.2 of the paper
+        super(KHopSumAggregator, self).__init__()
+        self.K = K  # max number of hops
+        self.M = M  # largest moment for aggregation
+
+    def forward(self, x, edge_index):
+        # x has shape [num_outputs (t), num_nodes, num_features]
+        # want output to have shape [num_outputs, num_nodes, K, M, num_features]
+        num_nodes = edge_index.max().item() + 1
+        # Compute the moments for each k-hop sum for each node
+        k_hop_sums = torch.zeros(x.size(0), x.size(1), self.K, self.M, x.size(2), device=x.device)
+        for node_idx in range(x.size(1)):
+            if node_idx >= num_nodes:
+                # this node has no incident edges (it is isolated);
+                # add in a self loop so k_hop_subgraph returns a valid neighborhood
+                self_loop = torch.tensor([[node_idx], [node_idx]], dtype=edge_index.dtype, device=edge_index.device)
+                edge_index = torch.cat([edge_index, self_loop], dim=1)
+            # get the k-hop subgraph for each node
+            for k in range(1, self.K+1):
+                subset, _, _, _ = k_hop_subgraph(node_idx, k, edge_index, relabel_nodes=False)
+                #print(f"Subset for node {node_idx}, k={k}: {subset}")
+                for m in range(1, self.M+1):
+                    k_hop_sums[:, node_idx, k-1, m-1, :] = self._get_moment_sum_aggr(x[:, subset, :], m)
+        return k_hop_sums
+
+    def _get_moment_sum_aggr(self, x, m):
+        # sum the absolute value of x raised to the mth power
+        # x has shape [num_outputs, |N(node)|, num_features]
+        return x.abs().pow(m).sum(dim=1)
+
+class GraphMomentAggregator(torch.nn.Module):
+    def __init__(self, S=4):
+        super(GraphMomentAggregator, self).__init__()
+        self.S = S
+
+    def forward(self, x, batch_index):
+        # x has shape [num_outputs, num_nodes, K, M, num_features]
+        # batch_index has shape [num_nodes]
+        # We want output to have shape [num_graphs, num_outputs, S, K, M, num_features]
+
+        # Get the number of graphs from the batch index
+        num_graphs = batch_index.max().item() + 1
+
+        # Initialize the output tensor
+        num_outputs, num_nodes_total, K, M, num_features = x.size()
+        graph_moments = torch.zeros(num_graphs, num_outputs, self.S, K, M, num_features, device=x.device)
+
+        # Compute moments for each graph in the batch
+        for g in range(num_graphs):
+            # Get node indices for the current graph
+            mask = (batch_index == g)
+
+            # Extract relevant data for the current graph
+            x_graph = x[:, mask]  # Shape [num_outputs, num_nodes_in_graph, K, M, num_features]
+
+            for s in range(1, self.S+1):
+                # Compute the s-th moment for the current graph
+                graph_moments[g, :, s-1, :, :, :] = self._get_graph_moment(x_graph, s)
+
+        return graph_moments
+
+    def _get_graph_moment(self, x, s):
+        # x has shape [num_outputs, num_nodes, K, M, num_features]
+        # Output should have shape [num_outputs, K, M, num_features]
+        return x.abs().pow(s).sum(dim=1)
+
+if __name__ == '__main__':
+    # Example usage
+    num_outputs = 2
+    num_nodes = 10
+    K = 3
+    M = 5
+    num_features = 4
+    batch_size = 5
+    S = 4
+
+    # Create example input
+    x = torch.rand((num_outputs, num_nodes, K, M, num_features))
+    edge_index = torch.randint(0, num_nodes, (2, num_nodes))
+    batch_index = torch.randint(0, batch_size, (num_nodes,))
+
+    # Initialize the aggregator
+    aggregator = GraphMomentAggregator(S=S)
+
+    # Compute the graph moments
+    graph_moments = aggregator(x, batch_index)
+
+    print("Graph moments shape:", graph_moments.shape)
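The two aggregators compose as a pipeline on the PDE_layer output. The shape walkthrough below is a hedged usage sketch (it assumes the file is importable as `aggregators`, as in DYMAG.py; the small edge_index is illustrative):

```python
# Shape walkthrough for the aggregation pipeline (illustrative sketch).
import torch
from aggregators import KHopSumAggregator, GraphMomentAggregator

T, N, F = 26, 10, 3                        # time samples, nodes, features
x = torch.rand(T, N, F)                    # stands in for PDE_layer output
edge_index = torch.tensor([[0, 1, 2, 3],   # directed ring on 4 nodes;
                           [1, 2, 3, 0]])  # nodes 4..9 get self-loops added
batch_index = torch.zeros(N, dtype=torch.long)

h = KHopSumAggregator(K=3, M=4)(x, edge_index)   # [T, N, K, M, F]
g = GraphMomentAggregator(S=4)(h, batch_index)   # [num_graphs, T, S, K, M, F]
print(h.shape)   # torch.Size([26, 10, 3, 4, 3])
print(g.shape)   # torch.Size([1, 26, 4, 3, 4, 3])
```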
+ """ + graph_params = {} + for param, bounds in self.graph_params_bounds.items(): + lower_bound, upper_bound = bounds + if param in ['n', 'k']: + value = np.random.randint(lower_bound, upper_bound + 1) + else: + value = np.random.uniform(lower_bound, upper_bound) + graph_params[param] = value + return graph_params + + def __len__(self): + """ + Returns the number of graphs in the dataset. + + Returns: + int: The number of graphs. + """ + return self.num_graphs + + def __getitem__(self, idx): + """ + Returns the graph and its corresponding parameters at the given index. + + Args: + idx (int): The index of the graph. + + Returns: + Data: The graph data object containing the node features, target values, and edge indices. + """ + graph = self.graphs[idx] + params = self.generating_parameters[idx] + if self.node_features == 'degree': + x = np.array([graph.degree[node] for node in graph.nodes()]).reshape(-1, 1) + if self.node_features == 'random': + n_random_signal = self.graph_params_bounds.get('n_random_signal', 5) + x = np.random.rand(graph.number_of_nodes(), n_random_signal) + else: + x = np.eye(graph.number_of_nodes()) + y = np.array([params[param] for param in self.graph_params_bounds.keys() if param in params and param != 'n']).reshape(-1, 1) + + edge_index = torch.tensor(list(graph.edges)).t().contiguous() + edge_index = to_undirected(edge_index) + + return Data(x=torch.tensor(x, dtype=torch.float), y=torch.tensor(y, dtype=torch.float), edge_index=edge_index) + +if __name__ == '__main__': + random_graph_model = 'sbm' # 'er' or 'sbm' + num_graphs = 10 + graph_params_bounds = { + 'n': (50, 50), + 'p': (0.1, 0.5), + 'k': (1, 5), + 'p_in': (0.5, 0.9), + 'p_out': (0.05, 0.2) + } + + graph_params_bounds = { + 'n': (50, 50), + 'k': (1, 5), + 'p_in': (0.5, 0.9), + 'p_out': (0.05, 0.2) + } + + dataset = RandomDataset(random_graph_model, num_graphs, graph_params_bounds) + dataloader = DataLoader(dataset, batch_size = 5, shuffle=True) + print(dataset[0]) + print(dataset[0].y) + assert dataset[0].is_undirected() + + for batch in dataloader: + print(batch) diff --git a/DYMAG_solver_version/lyapunov.py b/DYMAG_solver_version/lyapunov.py new file mode 100644 index 0000000..6a23249 --- /dev/null +++ b/DYMAG_solver_version/lyapunov.py @@ -0,0 +1,101 @@ +import nolds +import numpy as np +import matplotlib.pyplot as plt +import networkx as nx +import torch +from PDE_layer import PDE_layer +import phate +import time +import argparse +from data_module import RandomDataset +from torch_geometric.loader import DataLoader +import pandas as pd +import os +import multiprocessing as mp +import numpy as np + +def compute_lyapunov_for_node(traj): + # Calculate the maximum Lyapunov exponent for a given trajectory + return max(nolds.lyap_e(traj)) + +def compute_lyapunov_for_feature(x, num_nodes, traj_ind): + lyap_max = 0 + # Calculate Lyapunov exponents for each node in the feature + for node_ind in range(num_nodes): + traj = x[:, node_ind, traj_ind] + l = compute_lyapunov_for_node(traj) + if l > lyap_max: + lyap_max = l + return lyap_max + +def measure_lyapunov(x): + # Assume that x is outputs + x = x.detach().cpu().numpy() + num_steps, n_nodes, num_features = x.shape + + # Prepare arguments for parallel processing + args = [(x, n_nodes, traj_ind) for traj_ind in range(num_features)] + + # Use multiprocessing to parallelize across features + with mp.Pool() as pool: + lyap_max_list = pool.starmap(compute_lyapunov_for_feature, args) + + return lyap_max_list + + + +if __name__ == '__main__': + argparser = 
argparse.ArgumentParser() + argparser.add_argument('--dynamics', type=str, default='sprott', help='Dynamics to simulate') + argparser.add_argument('--n_node_list', type=int, nargs='+', default=[25, 50, 100, 317], help='List of number of nodes') + argparser.add_argument('--num_reps_per_n_node', type=int, default=5, help='Number of repetitions per number of nodes') + argparser.add_argument('--num_graphs', type=int, default=10, help='Number of graphs') + argparser.add_argument('--signal_type', type=str, default='dirac', help='Type of signal to use') + argparser.add_argument('--sampling_interval', type=float, default=0.2, help='Sampling interval') + argparser.add_argument('--final_t', type=float, default=20, help='Final time') + argparser.add_argument('--step_size', type=float, default=0.01, help='Step size') + args = argparser.parse_args() + + results = pd.DataFrame(columns=['n_nodes','rep','graph_ind','lyap_mean', 'lyap_min', 'lyap_max', 'lyap_std', 'time']) + for n_nodes in args.n_node_list: + print('Number of nodes:', n_nodes) + for rep in range(args.num_reps_per_n_node): + print('Repetition:', rep) + # get a random integer seed + seed = rep + torch.manual_seed(seed) + np.random.seed(seed) + + graph_params_bounds = {'n': (n_nodes, n_nodes), 'p': (.1, .5)} + # get dataset + dataset = RandomDataset(random_graph_model='er', + num_graphs=args.num_graphs, + graph_params_bounds=graph_params_bounds, + node_features=args.signal_type) + + dataloader = DataLoader(dataset, batch_size=1, shuffle=True) + pde_layer = PDE_layer(dynamics=args.dynamics, n_largest_graph = n_nodes,sampling_interval = args.sampling_interval, final_t = args.final_t, step_size = args.step_size) + + for graph_ind, data in enumerate(dataloader): + x = data.x + edge_index = data.edge_index + batch = data.batch + start_time = time.time() + outputs = pde_layer(x, edge_index, batch) + time_elapsed = time.time() - start_time + lyap_max_list = measure_lyapunov(outputs) + lyap_mean = np.mean(lyap_max_list) + lyap_min = np.min(lyap_max_list) + lyap_max = np.max(lyap_max_list) + lyap_std = np.std(lyap_max_list) + new_row = {'n_nodes': n_nodes, 'rep': rep, 'graph_ind': graph_ind, 'lyap_mean': lyap_mean, 'lyap_min': lyap_min, 'lyap_max': lyap_max, 'lyap_std': lyap_std, 'time': time_elapsed} + results = pd.concat([results, pd.DataFrame([new_row])], ignore_index=True) + print(f'Results for graph {graph_ind}:') + print(f'Lyapunov mean: {lyap_mean}, Lyapunov min: {lyap_min}, Lyapunov max: {lyap_max}, Lyapunov std: {lyap_std}, Time: {time_elapsed}') + + save_string = f'lyapunov_results_{args.dynamics}_{args.signal_type}_{args.sampling_interval}_{args.final_t}_{args.step_size}.csv' + results_dir = 'lyapunov_results' + if not os.path.exists(results_dir): + os.makedirs(results_dir) + save_string = os.path.join(results_dir, save_string) + results.to_csv(save_string, index=False) diff --git a/DYMAG_solver_version/main.py b/DYMAG_solver_version/main.py new file mode 100644 index 0000000..47de781 --- /dev/null +++ b/DYMAG_solver_version/main.py @@ -0,0 +1,79 @@ +from data_module import RandomDataset +from torch_geometric.loader import DataLoader +import numpy as np +import torch +import networkx as nx +from PDE_layer import PDE_layer +from DYMAG import DYMAG +import argparse + +if __name__ == '__main__': + args = argparse.ArgumentParser(description='Run DYMAG on a RandomDataset of Erdos-Renyi graphs') + args.add_argument('--num_graphs', type=int, default=25, help='Number of training graphs to generate') + args.add_argument('--num_graphs_test', 
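lyapunov.py relies on `nolds` to estimate Lyapunov exponents from sampled trajectories. A quick way to convince yourself the estimators behave sensibly is a system with a known exponent (a hedged sketch; requires `nolds` to be installed):

```python
# Sanity check of the nolds estimators on the logistic map at r = 4,
# whose largest Lyapunov exponent is ln(2) ~ 0.693 (illustrative only).
import numpy as np
import nolds

x = np.empty(1000)
x[0] = 0.3
for i in range(1, len(x)):
    x[i] = 4.0 * x[i - 1] * (1.0 - x[i - 1])

print(max(nolds.lyap_e(x)))   # Eckmann et al., as used in lyapunov.py
print(nolds.lyap_r(x))        # Rosenstein, as used in sprott_small_perturbation.py
```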
diff --git a/DYMAG_solver_version/main.py b/DYMAG_solver_version/main.py
new file mode 100644
index 0000000..47de781
--- /dev/null
+++ b/DYMAG_solver_version/main.py
@@ -0,0 +1,79 @@
+from data_module import RandomDataset
+from torch_geometric.loader import DataLoader
+import numpy as np
+import torch
+import networkx as nx
+from PDE_layer import PDE_layer
+from DYMAG import DYMAG
+import argparse
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Run DYMAG on a RandomDataset of Erdos-Renyi graphs')
+    parser.add_argument('--num_graphs', type=int, default=25, help='Number of training graphs to generate')
+    parser.add_argument('--num_graphs_test', type=int, default=5, help='Number of test graphs to generate')
+    parser.add_argument('--n_nodes', type=int, default=100, help='Number of nodes in each graph')
+    parser.add_argument('--p_min', type=float, default=0.1, help='Minimum edge probability')
+    parser.add_argument('--p_max', type=float, default=0.5, help='Maximum edge probability')
+    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on (cpu or cuda)')
+    parser.add_argument('--dynamic', type=str, default='sprott', help='Dynamics to model')
+    args = parser.parse_args()
+
+    # run DYMAG on a RandomDataset of Erdos-Renyi graphs
+    # set a seed for reproducibility
+    torch.manual_seed(0)
+    # set up the dataset
+    num_graphs = args.num_graphs
+    num_graphs_test = args.num_graphs_test
+    n_nodes = args.n_nodes
+    graph_params_bounds = {'n': (n_nodes, n_nodes), 'p': (args.p_min, args.p_max)}
+
+    train_dataset = RandomDataset(random_graph_model='er', num_graphs=num_graphs, graph_params_bounds=graph_params_bounds)
+    dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True)
+
+    test_dataset = RandomDataset(random_graph_model='er', num_graphs=num_graphs_test, graph_params_bounds=graph_params_bounds)
+
+    model = DYMAG(input_feature_dim=n_nodes, output_dim=1, dynamics=args.dynamic, n_largest_graph=n_nodes, device=args.device)
+
+    # set up the optimizer and loss function
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
+    criterion = torch.nn.MSELoss()
+
+    num_epochs = 100
+    # training loop
+    for epoch in range(num_epochs):
+        total_loss = 0
+
+        for data in dataloader:
+            optimizer.zero_grad()
+
+            # forward pass
+            output = model(data.x, data.edge_index, data.batch)
+
+            # compute loss
+            loss = criterion(output, data.y)
+
+            # backward pass and optimization
+            loss.backward()
+            optimizer.step()
+
+            total_loss += loss.item()
+
+        # print average loss for the epoch
+        avg_loss = total_loss / len(dataloader)
+        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
+
+    # run DYMAG on the test dataset and compute the average loss
+    total_loss = 0
+    test_dataloader = DataLoader(test_dataset, batch_size=5, shuffle=False)
+
+    for data in test_dataloader:
+        x = data.x
+        edge_index = data.edge_index
+        batch_index = data.batch
+        # forward pass
+        output = model(x, edge_index, batch_index)
+        loss = criterion(output, data.y)
+        total_loss += loss.item()
+
+    avg_loss = total_loss / len(test_dataloader)
+    print(f"Average Loss on Test Dataset: {avg_loss}")
diff --git a/DYMAG_solver_version/outputs_plot_heat.png b/DYMAG_solver_version/outputs_plot_heat.png
new file mode 100644
index 0000000..b39e047
Binary files /dev/null and b/DYMAG_solver_version/outputs_plot_heat.png differ
diff --git a/DYMAG_solver_version/phate_plot_heat.png b/DYMAG_solver_version/phate_plot_heat.png
new file mode 100644
index 0000000..13b1819
Binary files /dev/null and b/DYMAG_solver_version/phate_plot_heat.png differ
diff --git a/DYMAG_solver_version/see_edge_index.py b/DYMAG_solver_version/see_edge_index.py
new file mode 100644
index 0000000..e07c061
--- /dev/null
+++ b/DYMAG_solver_version/see_edge_index.py
@@ -0,0 +1,27 @@
+import torch
+from torch_geometric.datasets import TUDataset
+
+dataset = TUDataset(root='data/TUDataset', name='MUTAG')
+
+print()
+print(f'Dataset: {dataset}:')
+print('====================')
+print(f'Number of graphs: {len(dataset)}')
+print(f'Number of features: {dataset.num_features}')
+print(f'Number of classes: {dataset.num_classes}')
+
+data = dataset[0]  # Get the first graph object.
+
+print()
+print(data)
+print('=============================================================')
+
+# Gather some statistics about the first graph.
+print(f'Number of nodes: {data.num_nodes}')
+print(f'Number of edges: {data.num_edges}')
+print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
+print(f'Has isolated nodes: {data.has_isolated_nodes()}')
+print(f'Has self-loops: {data.has_self_loops()}')
+print(f'Is undirected: {data.is_undirected()}')
+
+print(data.edge_index)
\ No newline at end of file
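As see_edge_index.py illustrates, PyG stores connectivity as a `[2, num_edges]` COO tensor, with undirected graphs carrying each edge in both directions — the same convention data_module.py enforces with `to_undirected`. A minimal reminder (illustrative sketch):

```python
# COO edge_index format: row 0 holds source nodes, row 1 holds targets.
import torch
from torch_geometric.utils import to_undirected

edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])   # directed 3-cycle
edge_index = to_undirected(edge_index)   # each edge now appears in both directions
print(edge_index.shape)                  # torch.Size([2, 6])
```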
diff --git a/DYMAG_solver_version/sprott_small_perturbation.py b/DYMAG_solver_version/sprott_small_perturbation.py
new file mode 100644
index 0000000..83534f2
--- /dev/null
+++ b/DYMAG_solver_version/sprott_small_perturbation.py
@@ -0,0 +1,158 @@
+import nolds
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import torch
+from PDE_layer import PDE_layer
+import phate
+import time
+import argparse
+from data_module import RandomDataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.data import Data, Dataset
+import pandas as pd
+import os
+import multiprocessing as mp
+
+def compute_lyapunov_for_node(traj):
+    # Calculate the largest Lyapunov exponent for a given trajectory
+    return nolds.lyap_r(traj)
+
+def compute_lyapunov_for_feature(x, num_nodes, traj_ind):
+    lyap_max = 0
+    # Calculate Lyapunov exponents for each node in the feature
+    for node_ind in range(num_nodes):
+        traj = x[:, node_ind, traj_ind]
+        l = compute_lyapunov_for_node(traj)
+        if l > lyap_max:
+            lyap_max = l
+    return lyap_max
+
+def measure_lyapunov(x):
+    # Assume that x is outputs
+    x = x.detach().cpu().numpy()
+    num_steps, n_nodes, num_features = x.shape
+
+    # Prepare arguments for parallel processing
+    args = [(x, n_nodes, traj_ind) for traj_ind in range(num_features)]
+
+    # Use multiprocessing to parallelize across features
+    with mp.Pool() as pool:
+        lyap_max_list = pool.starmap(compute_lyapunov_for_feature, args)
+
+    return lyap_max_list
+
+def get_time_diffs(x1, x2):
+    # x1 and x2 have shape (num_steps, n_nodes, num_features)
+    diff = x1 - x2
+    # take a matrix norm over the node and feature axes, giving one value per step
+    norm = np.linalg.norm(diff, axis=(1, 2))
+    return norm
+
+if __name__ == '__main__':
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument('--n_nodes', type=int, default=25, help='Number of nodes')
+    argparser.add_argument('--num_reps_per_n_node', type=int, default=5, help='Number of repetitions per number of nodes')
+    argparser.add_argument('--num_graphs', type=int, default=10, help='Number of graphs')
+    argparser.add_argument('--signal_type', type=str, default='dirac', help='Type of signal to use')
+    argparser.add_argument('--sampling_interval', type=float, default=0.2, help='Sampling interval')
+    argparser.add_argument('--final_t', type=float, default=20, help='Final time')
+    argparser.add_argument('--step_size', type=float, default=0.01, help='Step size')
+    argparser.add_argument('--edge_addition', type=int, default=2, help='Number of edges to add to the graph')
+    args = argparser.parse_args()
+
+    # generate one graph
+    n_nodes = args.n_nodes
+    seed = 24
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    signal_type = args.signal_type
+
+    graph_params_bounds = {'n': (n_nodes, n_nodes), 'p': (.1, .5)}
+    # get dataset
+    data_unperturbed = RandomDataset(random_graph_model='er',
+                                     num_graphs=1,
+                                     graph_params_bounds=graph_params_bounds,
+                                     node_features=signal_type)
+
+    # create a data object with the same graph but add a random edge
+    data_unperturbed = data_unperturbed[0]
+    G = data_unperturbed
+    edge_index = G.edge_index
+    edge_addition = args.edge_addition
+    added_edges = 0
+    while added_edges < edge_addition:
+        node0, node1 = torch.randint(0, n_nodes, (2,))
+        if node0 == node1:
+            # skip self-loops
+            continue
+        edge_in_index = (node0 == edge_index[0]) & (node1 == edge_index[1])
+        if edge_in_index.any():
+            continue
+        edge_index = torch.cat([edge_index, torch.tensor([[node0, node1], [node1, node0]])], dim=1)
+        print(f'added edge between {node0} and {node1}')
+        added_edges += 1
+    data_perturbed = Data(x=G.x, edge_index=edge_index)
+
+    pde_layer_sprott = PDE_layer(dynamics='sprott', n_largest_graph=n_nodes, sampling_interval=args.sampling_interval, final_t=args.final_t, step_size=args.step_size)
+    pde_layer_heat = PDE_layer(dynamics='heat', n_largest_graph=n_nodes, sampling_interval=args.sampling_interval, final_t=args.final_t, step_size=args.step_size)
+
+    batch = torch.tensor([0 for _ in range(n_nodes)], dtype=torch.long)
+    sprott_perturbed = pde_layer_sprott(data_perturbed.x, data_perturbed.edge_index, batch)
+    sprott_unperturbed = pde_layer_sprott(data_unperturbed.x, data_unperturbed.edge_index, batch)
+    heat_perturbed = pde_layer_heat(data_perturbed.x, data_perturbed.edge_index, batch)
+    heat_unperturbed = pde_layer_heat(data_unperturbed.x, data_unperturbed.edge_index, batch)
+
+    # get lyapunov exponents
+    print('getting lyapunov for sprott perturbed')
+    lyap_sprott_perturbed = measure_lyapunov(sprott_perturbed)
+    print('getting lyapunov for sprott unperturbed')
+    lyap_sprott_unperturbed = measure_lyapunov(sprott_unperturbed)
+    print('getting lyapunov for heat perturbed')
+    lyap_heat_perturbed = measure_lyapunov(heat_perturbed)
+    print('getting lyapunov for heat unperturbed')
+    lyap_heat_unperturbed = measure_lyapunov(heat_unperturbed)
+
+    # output has shape (num_steps, n_nodes, num_features)
+    deltas_sprott = get_time_diffs(sprott_perturbed, sprott_unperturbed)
+    deltas_heat = get_time_diffs(heat_perturbed, heat_unperturbed)
+
+    # print out deltas for sprott and heat every 5 recorded steps (one second of simulated time)
+    print('sprott deltas every second:')
+    deltas_sprott_sparse = deltas_sprott[::5]
+    print(deltas_sprott_sparse)
+    print('heat deltas every second:')
+    deltas_heat_sparse = deltas_heat[::5]
+    print(deltas_heat_sparse)
+
+    # plot the results
+    plt.figure()
+    plt.plot(deltas_sprott, label='Sprott')
+    plt.plot(deltas_heat, label='Heat')
+    plt.legend()
+    plt.xlabel('Time step')
+    plt.ylabel('Norm of difference')
+    plt.title('Difference between perturbed and unperturbed graphs')
+    plt.show()
+
+    # make a 2x2 subplot and plot the solution over time at node 0, feature 0
+    plt.figure()
+    plt.subplot(2, 2, 1)
+    plt.plot(sprott_perturbed[:, 0, 0])
+    plt.title('Sprott perturbed')
+    plt.subplot(2, 2, 2)
+    plt.plot(sprott_unperturbed[:, 0, 0])
+    plt.title('Sprott unperturbed')
+    plt.subplot(2, 2, 3)
+    plt.plot(heat_perturbed[:, 0, 0])
+    plt.title('Heat perturbed')
+    plt.subplot(2, 2, 4)
+    plt.plot(heat_unperturbed[:, 0, 0])
+    plt.title('Heat unperturbed')
+    plt.show()
+
+    # print out lyapunov exponents
+    print('Lyapunov exponents:')
+    print('Sprott perturbed:', lyap_sprott_perturbed)
+    print('Sprott unperturbed:', lyap_sprott_unperturbed)
+    print('Heat perturbed:', lyap_heat_perturbed)
+    print('Heat unperturbed:', lyap_heat_unperturbed)
diff --git a/DYMAG_solver_version/test_PDE_layer.py b/DYMAG_solver_version/test_PDE_layer.py
new file mode 100644
index 0000000..e7548cc
--- /dev/null
+++ b/DYMAG_solver_version/test_PDE_layer.py
@@ -0,0 +1,66 @@
+import torch
+import networkx as nx
+from PDE_layer import PDE_layer
+
+def test_PDE_layer():
+    # Create a sample graph
+    G = nx.cycle_graph(5)
+    edge_index = torch.tensor(list(G.edges)).t().contiguous()
+    # make the edge index undirected
+    edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
+    x = torch.randn(5, 3)
+    batch = torch.tensor([0, 0, 0, 0, 0])
+
+    # Create a PDE_layer instance
+    pde_layer = PDE_layer(dynamics='heat')
+
+    # Perform forward pass
+    outputs = pde_layer(x, edge_index, batch)
+
+    # Assert the shape of the outputs
+    assert outputs.shape == (26, 5, 3)
+
+    # Assert the values of the outputs
+    assert torch.allclose(outputs[0], x)
+    #assert torch.allclose(outputs[-1], x)
+    print("PDE_layer test passed!")
+
+def test_PDE_layer_batch():
+    # Create multiple sample graphs
+    G1 = nx.cycle_graph(5)     # 5 nodes
+    G2 = nx.complete_graph(3)  # 3 nodes
+    G3 = nx.star_graph(3)      # 4 nodes (hub plus 3 leaves)
+
+    # Convert graphs to edge index format and make them undirected
+    edge_index1 = torch.tensor(list(G1.edges)).t().contiguous()
+    edge_index1 = torch.cat([edge_index1, edge_index1[[1, 0]]], dim=1)
+    edge_index2 = torch.tensor(list(G2.edges)).t().contiguous()
+    edge_index2 = torch.cat([edge_index2, edge_index2[[1, 0]]], dim=1)
+    edge_index3 = torch.tensor(list(G3.edges)).t().contiguous()
+    edge_index3 = torch.cat([edge_index3, edge_index3[[1, 0]]], dim=1)
+
+    # Create node features for each graph
+    x1 = torch.randn(5, 3)
+    x2 = torch.randn(3, 3)
+    x3 = torch.randn(4, 3)
+
+    # Create batch tensor
+    batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
+
+    # Create a PDE_layer instance
+    pde_layer = PDE_layer(dynamics='heat')
+
+    # Batch the graphs together, offsetting each graph's node indices so they
+    # point into the rows of the batched feature matrix
+    batched_edge_index = torch.cat([edge_index1, edge_index2 + 5, edge_index3 + 8], dim=1)
+    batched_x = torch.cat([x1, x2, x3], dim=0)
+
+    # Perform forward pass
+    outputs = pde_layer(batched_x, batched_edge_index, batch)
+    # Assert the shape of the outputs
+    # output takes on shape (num_steps, num_nodes, num_features)
+    assert outputs.shape == (26, 12, 3)
+
+    # Assert the values of the outputs
+    #assert torch.allclose(outputs[0], x1)
+    #assert torch.allclose(outputs[5], x2)
+    #assert torch.allclose(outputs[15], x3)
+    print("PDE_layer batch test passed!")
+
+
+test_PDE_layer()
+test_PDE_layer_batch()
\ No newline at end of file
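test_PDE_layer.py batches graphs by hand, which is why the node-index offsets above matter. In practice PyG's `Batch.from_data_list` applies the same offsets automatically; an equivalent sketch:

```python
# Batch.from_data_list shifts each graph's node indices automatically,
# matching the manual offsets used in the batch test above (sketch).
import torch
from torch_geometric.data import Batch, Data

g1 = Data(x=torch.randn(5, 3), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(3, 3), edge_index=torch.tensor([[0, 2], [1, 0]]))
batch = Batch.from_data_list([g1, g2])

print(batch.edge_index)   # g2's node ids are shifted by 5
print(batch.batch)        # tensor([0, 0, 0, 0, 0, 1, 1, 1])
```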
diff --git a/DYMAG_solver_version/test_data_module.py b/DYMAG_solver_version/test_data_module.py
new file mode 100644
index 0000000..0ab2d6d
--- /dev/null
+++ b/DYMAG_solver_version/test_data_module.py
@@ -0,0 +1,50 @@
+import unittest
+from data_module import RandomDataset
+import networkx as nx
+import torch_geometric
+
+class TestRandomDataset(unittest.TestCase):
+    def test_generate_graphs(self):
+        # Create a RandomDataset instance
+        dataset = RandomDataset(random_graph_model='er', num_graphs=5, graph_params_bounds={'n': (10, 20), 'p': (0.1, 0.5)})
+
+        # Check if the number of generated graphs matches the specified number of graphs
+        self.assertEqual(len(dataset.graphs), 5)
+
+        # Check if the generating parameters are correctly stored
+        self.assertEqual(len(dataset.generating_parameters), 5)
+
+        # Check if each generated graph is of type networkx.Graph
+        for graph in dataset.graphs:
+            self.assertIsInstance(graph, nx.Graph)
+
+    def test_sample_graph_params(self):
+        # Create a RandomDataset instance
+        dataset = RandomDataset(random_graph_model='er', num_graphs=1, graph_params_bounds={'n': (10, 20), 'p': (0.1, 0.5)})
+
+        # Sample graph parameters
+        graph_params = dataset.sample_graph_params()
+
+        # Check if the sampled graph parameters are within the specified bounds
+        self.assertGreaterEqual(graph_params['n'], 10)
+        self.assertLessEqual(graph_params['n'], 20)
+        self.assertGreaterEqual(graph_params['p'], 0.1)
+        self.assertLessEqual(graph_params['p'], 0.5)
+
+    def test_getitem(self):
+        # Create a RandomDataset instance
+        dataset = RandomDataset(random_graph_model='er', num_graphs=1, graph_params_bounds={'n': (10, 20), 'p': (0.1, 0.5)})
+
+        # Get the graph data object at index 0
+        data = dataset[0]
+
+        # Check if the returned object is of type torch_geometric.data.Data
+        self.assertIsInstance(data, torch_geometric.data.Data)
+
+        # With node_features=None the features default to an identity matrix,
+        # and y holds the generating parameters other than 'n' (here just 'p')
+        self.assertEqual(data.x.shape, (data.num_nodes, data.num_nodes))
+        self.assertEqual(data.y.shape, (1, 1))
+        self.assertEqual(data.edge_index.shape[0], 2)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/DYMAG_solver_version/train_pl.py b/DYMAG_solver_version/train_pl.py
new file mode 100644
index 0000000..5c7f139
--- /dev/null
+++ b/DYMAG_solver_version/train_pl.py
@@ -0,0 +1,47 @@
+import pytorch_lightning as pl
+import torch
+from torch_geometric.loader import DataLoader
+from data_module import RandomDataset
+from DYMAG_pl import DYMAG_pl
+import wandb
+
+# Initialize wandb
+wandb.init(project='DYMAG_project')
+
+# Set a seed for reproducibility
+torch.manual_seed(0)
+
+# Set up the dataset
+num_graphs = 25
+num_graphs_test = 5
+num_graphs_validation = 5
+n_nodes = 10
+graph_params_bounds = {'n': (n_nodes, n_nodes), 'p': (0.1, 0.5)}
+
+train_dataset = RandomDataset(random_graph_model='er', num_graphs=num_graphs, graph_params_bounds=graph_params_bounds)
+dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True)
+
+test_dataset = RandomDataset(random_graph_model='er', num_graphs=num_graphs_test, graph_params_bounds=graph_params_bounds)
+test_dataloader = DataLoader(test_dataset, batch_size=5, shuffle=False)
+
+validation_dataset = RandomDataset(random_graph_model='er', num_graphs=num_graphs_validation, graph_params_bounds=graph_params_bounds)
+validation_dataloader = DataLoader(validation_dataset, batch_size=5, shuffle=False)
+
+
+# Initialize the model
+model = DYMAG_pl(input_feature_dim=n_nodes, output_dim=1, dynamics='heat')
+
+# Define the trainer
+trainer = pl.Trainer(
+    max_epochs=100,
+    logger=pl.loggers.WandbLogger(),
+    )
+
+# Fit the model
+trainer.fit(model, dataloader, validation_dataloader)
+
+trainer.test(model, test_dataloader)
+
+# Close the wandb run
+wandb.finish()
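train_pl.py seeds only torch; if exact reproducibility across Python's `random`, NumPy, and DataLoader workers matters, Lightning offers one-call seeding (a suggestion, not part of the diff):

```python
# Optional: seed python, numpy, and torch in one call
# (suggestion; not used in train_pl.py as committed).
import pytorch_lightning as pl

pl.seed_everything(0, workers=True)
```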
diff --git a/DYMAG_solver_version/visualize_trajectories.py b/DYMAG_solver_version/visualize_trajectories.py
new file mode 100644
index 0000000..e23f566
--- /dev/null
+++ b/DYMAG_solver_version/visualize_trajectories.py
@@ -0,0 +1,209 @@
+import tphate
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import torch
+from PDE_layer import PDE_layer
+import phate
+
+if __name__ == '__main__':
+    # get a random integer seed
+    seed = np.random.randint(0, 1000)
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+    # Create an ER graph
+    n_nodes = 25
+    p = 0.2
+    num_traj = 5
+    dynamics = 'heat'
+    small_perturb = True
+    if dynamics == 'sprott':
+        phate_knn = 5
+        final_t = 100
+        sampling_interval = .5
+    else:
+        phate_knn = 20
+        final_t = 7
+        sampling_interval = 0.1
+
+    visualization = 'phate'
+    # print out the dynamics, visualization, and other parameters
+    print(f'Dynamics: {dynamics}')
+    print(f'Visualization: {visualization}')
+    print(f'Number of nodes: {n_nodes}')
+    print(f'Number of trajectories: {num_traj}')
+    print(f'Final time: {final_t}')
+    print(f'Sampling interval: {sampling_interval}')
+
+    G = nx.erdos_renyi_graph(n_nodes, p)
+    edge_index = torch.tensor(list(G.edges)).t().contiguous()
+    edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
+    if small_perturb:
+        x = torch.randn(n_nodes, 1) * 2
+        # replicate this num_traj times
+        x = x.repeat(1, num_traj)
+        # add noise
+        x = x + torch.randn(n_nodes, num_traj)
+    else:
+        x = torch.randn(n_nodes, num_traj)
+    # zero center the data
+    x = x - x.mean(dim=0)
+    batch = torch.tensor([0 for _ in range(n_nodes)], dtype=torch.long)
+
+    # Create a PDE_layer instance
+    pde_layer = PDE_layer(dynamics=dynamics, n_largest_graph=n_nodes, sampling_interval=sampling_interval, final_t=final_t, step_size=0.01)
+
+    # Perform forward pass
+    outputs = pde_layer(x, edge_index, batch)
+    # outputs has shape (num_steps, num_nodes, num_features)
+    # plot the outputs at node 0 across all time steps for each trajectory, all in the same plot
+    for traj_ind in range(num_traj):
+        plt.plot(outputs[:, 0, traj_ind], label=f'Trajectory {traj_ind}')
+    plt.legend()
+    # set the x-axis label to be the time step
+    plt.xlabel('Forward recorded step')
+    plt.ylabel('Node 0 value')
+    plt.title('dynamics: ' + dynamics)
+    # save the figure
+    plt.savefig(f'figs/outputs_plot_{dynamics}_{final_t}_{sampling_interval}_{phate_knn}_{seed}.png')
+    # close
+    plt.close()
+
+    if visualization == 'tphate':
+        # outputs has shape (num_steps, num_nodes, num_features)
+        # rearrange to (num_features, num_steps, num_nodes)
+        outputs = outputs.permute(2, 0, 1).detach().numpy()
+        import pdb; pdb.set_trace()
+        # loop through each trajectory
+        for traj_ind in range(num_traj):
+            tphate_op = tphate.TPHATE(n_components=3, n_jobs=-1, verbose=0)
+            data_tphate = tphate_op.fit_transform(outputs[traj_ind])
+            fig = plt.figure()
+            ax = fig.add_subplot(111, projection='3d')
+            ax.scatter(data_tphate[:, 0], data_tphate[:, 1], data_tphate[:, 2])
+            plt.show()
+    elif visualization == 'phate':
+        # outputs has shape (num_steps, num_nodes, num_features)
+        total_n_steps = outputs.shape[0]
+        # sparsify the late-time outputs when the dynamics are heat (they change slowly there)
+        if dynamics == 'heat':
+            outputs_time_mask = torch.ones(total_n_steps, dtype=int)
+            midpt = total_n_steps // 2
+            for i in range(midpt, total_n_steps):
+                if i % 5 == 0:
+                    outputs_time_mask[i] = 0
+            for i in range(midpt + midpt//2 - 4, total_n_steps):
+                if np.random.rand() < 0.8:
+                    outputs_time_mask[i] = 0
+            outputs = outputs[outputs_time_mask == 1]
+
+        num_steps, num_nodes, num_features = outputs.shape
+        print(f'num_steps: {num_steps}, num_nodes: {num_nodes}, num_features: {num_features}, total_n_steps: {total_n_steps}')
+        # rearrange to (num_steps * num_features, num_nodes)
+        outputs = outputs.permute(2, 0, 1).reshape(num_features * num_steps, num_nodes).detach().numpy()
+
+        # label each row with its time step (a tile/repeat one-liner would also
+        # work; see the sketch at the end of this diff)
+        time_tracker = torch.zeros((num_steps, num_nodes, num_features), dtype=int)
+        # set each value to the time step index
+        for i in range(num_steps):
+            time_tracker[i] = i
+        time_tracker = time_tracker.permute(2, 0, 1).reshape(num_features * num_steps, num_nodes).detach().numpy()
+        time_tracker = time_tracker[:, 0]
+        # make something to track the trajectory (feature) index
+        traj_tracker = torch.zeros((num_steps, num_nodes, num_features), dtype=int)
+        # set each value to the trajectory index
+        for i in range(num_features):
+            traj_tracker[:, :, i] = i
+        traj_tracker = traj_tracker.permute(2, 0, 1).reshape(num_features * num_steps, num_nodes).detach().numpy()
+        traj_tracker = traj_tracker[:, 0]
+
+        # Perform PHATE
+        phate_op = phate.PHATE(n_components=3, n_jobs=-1, verbose=1, knn=phate_knn)
+        data_phate = phate_op.fit_transform(outputs)
+
+        # # Plot the results
+        # fig = plt.figure()
+        # ax = fig.add_subplot(111, projection='3d')
+        # # color by the original time step
+        # ax.scatter(data_phate[:, 0], data_phate[:, 1], data_phate[:, 2], c=time_tracker)
+        # # add color bar
+        # #ax.scatter(data_phate[:, 0], data_phate[:, 1], data_phate[:, 2])
+        # plt.show()
+
+        # Plot the results
+        fig = plt.figure()
+        ax = fig.add_subplot(111, projection='3d')
+
+        # Define a colormap
+        colormap = plt.cm.viridis
+
+        # Normalize time_tracker for color mapping
+        norm = plt.Normalize(vmin=time_tracker.min(), vmax=time_tracker.max())
+
+        # Marker list
+        marker_list = ['o', 'x', 's', 'D', '^']
+
+        # Plot each trajectory with different markers and colors
+        for traj_ind in range(num_features):
+            traj_mask = traj_tracker == traj_ind
+            colors = colormap(norm(time_tracker[traj_mask]))
+            ax.scatter(data_phate[traj_mask, 0], data_phate[traj_mask, 1], data_phate[traj_mask, 2],
+                       label=f'IC {traj_ind}', c=colors, marker=marker_list[traj_ind])
+
+        # Add legend
+        plt.legend()
+
+        # Add a color bar for the time
+        mappable = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
+        mappable.set_array(time_tracker)
+        cbar = plt.colorbar(mappable, ax=ax)
+        # turn off color bar ticks
+        cbar.set_ticks([])
+        cbar.set_label('Time')
+        # turn off axis ticks
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.set_zticks([])
+        plt.title(f'PHATE of {dynamics} dynamics')
+
+        # Save the plot
+        plt.savefig(f'figs/phate_plot_{dynamics}_{final_t}_{sampling_interval}_{phate_knn}_{seed}_{small_perturb}.png')
+        # Show the plot
+        plt.show()
+        plt.close()
+        print(f'saved to figs/phate_plot_{dynamics}_{final_t}_{sampling_interval}_{phate_knn}_{seed}_{small_perturb}.png')
+
+        # Plot the results again, colored by trajectory
+        fig = plt.figure()
+        ax = fig.add_subplot(111, projection='3d')
+
+        color_dict = {0: 'r', 1: 'g', 2: 'b', 3: 'c', 4: 'm'}
+
+        # Marker list
+        marker_list = ['o', 'x', 's', 'D', '^']
+
+        # Plot each trajectory with different markers and colors
+        for traj_ind in range(num_features):
+            traj_mask = traj_tracker == traj_ind
+            ax.scatter(data_phate[traj_mask, 0], data_phate[traj_mask, 1], data_phate[traj_mask, 2],
+                       label=f'IC {traj_ind}', c=color_dict[traj_ind], marker=marker_list[traj_ind])
+
+        # Add legend
+        plt.legend()
+
+        # turn off axis ticks
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.set_zticks([])
+        plt.title(f'PHATE of {dynamics} dynamics')
+
+        # Save the plot
+        plt.savefig(f'figs/phate_plot_traj_color_{dynamics}_{final_t}_{sampling_interval}_{phate_knn}_{seed}_{small_perturb}.png')
+        # Show the plot
+        plt.show()
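Finally, the per-row time and trajectory labels that visualize_trajectories.py builds with intermediate 3-D tensors can be written more directly. An equivalent sketch for the row ordering used above (feature-major, then time):

```python
# Equivalent construction of the per-row time and trajectory labels used in
# visualize_trajectories.py, without the intermediate 3-D tensors (sketch).
import numpy as np

num_steps, num_features = 26, 5
# rows are ordered feature-major: all steps of trajectory 0, then trajectory 1, ...
time_tracker = np.tile(np.arange(num_steps), num_features)     # 0..T-1, repeated per trajectory
traj_tracker = np.repeat(np.arange(num_features), num_steps)   # trajectory id, T copies each
```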