diff --git a/README.md b/README.md
index 2770c60..c149c7c 100644
--- a/README.md
+++ b/README.md
@@ -44,9 +44,37 @@ Or alternatively run this command:
 Please note there is another package called spectra which is not related to this tool. Spectrae (which stands for spectral evaluation) implements the spectral framework for model evaluation.
+## Definition of terms
+
+This work and GitHub repository use terms related to the **spectral framework for model evaluation**. Below is a quick refresher on these key concepts.
+
+### **Spectral Property**
+Every dataset has an underlying property that, as it changes, causes model performance to decrease. This is referred to as the **spectral property**.
+
+However, **not every property qualifies as a spectral property**.
+For example:
+- When predicting protein structure, the performance of a protein folding model does **not** change based on the number of **M** amino acids in a sequence.
+- Instead, model performance **does** change based on **structural similarity**, which is an example of a **spectral property**.
+
+### **Spectral Property Graph (SPG)**
+For a given dataset, a **spectral property graph (SPG)** is defined as:
+- **Nodes**: Samples in the dataset.
+- **Edges**: Connections between samples that share a spectral property.
+
+Every SPG is defined by a flattened adjacency matrix, which saves memory and allows SPECTRA to utilize GPUs to speed up computation.
+
+### **Spectral Parameter**
+The **spectral parameter** can be thought of as a **sparsification probability**.
+
+When SPECTRA runs on an SPG:
+1. It selects a random node.
+2. It decides whether to **delete edges** with a certain probability; this probability is the **spectral parameter**.
+3. The closer the spectral parameter is to **1**, the **stricter** the splits generated by SPECTRA will be.
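The three steps above can be illustrated with a toy sketch on a binary graph. This is not the package's implementation: the `sparsify` function, the dictionary-of-sets graph format, and the four-node path graph are all hypothetical. The sketch also removes neighbouring samples (as the binary branch of `run_independent_set` in this diff appears to do) rather than individual edges.

```python
import random


def sparsify(adjacency, spectral_parameter, seed=42):
    """Toy illustration of the spectral parameter on a binary graph.

    adjacency: dict mapping each node to the set of its neighbours
               (a hypothetical format, not the flattened adjacency matrix).
    spectral_parameter: probability of dropping each neighbour of a chosen node.
    """
    rng = random.Random(seed)
    remaining = set(adjacency)
    kept = []
    while remaining:
        # Step 1: select a random node among those not yet processed.
        node = rng.choice(sorted(remaining))
        remaining.discard(node)
        kept.append(node)
        # Step 2: drop each still-present neighbour with probability
        # equal to the spectral parameter.
        for neighbor in sorted(adjacency[node] & remaining):
            if rng.random() < spectral_parameter:
                remaining.discard(neighbor)
    return kept


# A path graph 0 - 1 - 2 - 3, where adjacent samples are "similar".
graph = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}
print(sparsify(graph, 0.0))  # nothing is dropped: every sample is kept
print(sparsify(graph, 1.0))  # every neighbour is dropped: no two kept samples are similar
```

With a spectral parameter of `0` every sample survives, while a value of `1` leaves a set in which no two samples share an edge, which is the sense in which splits become stricter (step 3).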
+
+
 ## How to use spectra
-### Step 1: Define the spectral property, cross-split overlap, and the spectra dataset wrapper
+### Step 1: Define the spectral property and the spectra dataset wrapper
 To run spectra you must first define important two abstract classes, Spectra and SpectraDataset.
@@ -86,7 +114,7 @@ class [Name]_Dataset(SpectraDataset):
         pass
 ```
-Spectra implements the user definition of the spectra property and cross split overlap.
+Spectra implements the user definition of the spectral property.
 ```python
@@ -103,52 +131,62 @@ class [Name]_spectra(spectra):
         '''
         return similarity
-    def cross_split_overlap(self, train, test):
-        '''
-        Define this function to return the overlap between a list of train and test samples.
+```
+### Step 2: Initialize SPECTRA and calculate the flattened adjacency matrix
-        Example: Average pairwise similarity between train and test set protein sequences.
+1. **Initialize SPECTRA**
+   - Initially, pass in no spectral property graph.
-        '''
-
+2. **Pass SPECTRA and dataset into the `Spectra_Property_Graph_Constructor`**
+   - Additional arguments:
+     - **`num_chunks`**: If your dataset is very large, you can split up the construction into chunks to allow multiple jobs to compute similarity. This parameter controls the number of chunks.
+     - **`binary`**: If `True`, the similarity returns either `0` or `1`; otherwise, it returns a floating-point number.
-        return cross_split_overlap
-```
-### Step 2: Initialize SPECTRA and precalculate pairwise spectral properties
+3. **Call `create_adjacency_matrix`**
+   - This function takes in the **chunk number** to calculate:
+     - If `num_chunks = 0`, the pairwise similarity is calculated in one go, so the input to `create_adjacency_matrix` should be `0`.
+     - If `num_chunks = 10`, the input should be the chunk number you want to calculate (e.g., `0` to `9`).
+
+4. **Combine the adjacency matrices**
+   - Call `combine_adjacency_matrices()` in the graph constructor to combine all the adjacency matrices into a single matrix.
-Initialize SPECTRA, passing in True or False to the binary argument if the spectral property returns a binary or continuous value. Then precalculate the pairwise spectral properties.
 ```python
-init_spectra = [name]_spectra([name]_Dataset, binary = True)
-init_spectra.pre_calculate_spectra_properties([name])
+from spectrae import Spectra_Property_Graph_Constructor
+spectra = [name]_spectra([name]_Dataset, spg=None)
+construct_spg = Spectra_Property_Graph_Constructor(spectra, [name]_Dataset, num_chunks = 0, binary = [False/True])
+construct_spg.create_adjacency_matrix(0)
+construct_spg.combine_adjacency_matrices()
 ```
-### Step 3: Initialize SPECTRA and precalculate pairwise spectral properties
-Generate SPECTRA splits. The ```generate_spectra_splits``` function takes in 4 important parameters:
-1. ```number_repeats```: the number of times to rerun SPECTRA for the same spectral parameter, the number of repeats must equal the number of seeds as each rerun uses a different seed.
-2. ```random_seed```: the random seeds used by each SPECTRA rerun, [42, 44] indicates two reruns the first of which will use a random seed of 42, the second will use 44.
-3. ```spectra_parameters```: the spectral parameters to run on, they must range from 0 to 1 and be string formatted to the correct number of significant figures to avoid float formatting errors.
-4. ```force_reconstruct```: True to force the model to regenerate SPECTRA splits even if they have already been generated.
+### Step 3: Generate SPECTRA Splits
-```python
-spectra_parameters = {'number_repeats': 3,
-                      'random_seed': [42, 44, 46],
-                      'spectral_parameters': ["{:.2f}".format(i) for i in np.arange(0, 1.05, 0.05)],
-                      'force_reconstruct': True,
-                      }
+1. **Initialize the Spectral Property Graph**
+   - Pass in the flattened adjacency matrix you just generated to the Spectral_Property_Graph to create the spectral property graph.
-init_spectra.generate_spectra_splits(**spectra_parameters)
+2. **Recreate SPECTRA**
+   - Use the SPECTRA dataset along with the created spectral property graph to reinstantiate SPECTRA.
+3. **Call `generate_spectra_split`** with the following arguments:
+   - **`spectra_param`**: The spectral parameter to run, must be between `0` and `1` (inclusive).
+   - **`degree_choosing`**: Only applicable to binary graphs; optimizes the algorithm by prioritizing deletion of nodes with a low degree first.
+   - **`num_splits`**: Number of splits to generate (usually `20`, which translates to spectral parameters between `0` and `1` in intervals of `0.05`).
+   - **`path_to_save`**: Location to store generated SPECTRA splits.
+   - **`debug_mode`**: Controls the amount of information to output.
+
+```python
+spg = Spectral_Property_Graph(FlattenedAdjacency("flattened_adjacency_matrix.pt"))
+spectra = [name]_spectra(dataset, spg)
+spectra.generate_spectra_split(spectra_param, degree_choosing = [True/False], num_splits = [int], path_to_save="", debug_mode = [True/False])
 ```
 ### Step 4: Investigate generated SPECTRA splits
-After SPECTRA has completed, the user should investigate the generated splits. Specifically ensuring that on average the cross-split overlap decreases as the spectral parameter increases. This can be achieved by using ```return_all_split_stats``` to gather the cross_split_overlap, train size, and test size of each generated split. Example outputs can be seen in the tutorials.
+After SPECTRA has completed, the user should investigate the generated splits, specifically checking that on average the cross-split overlap decreases as the spectral parameter increases. This can be achieved by using ```return_all_split_stats``` to gather the cross_split_overlap, train size, and test size of each generated split. Example outputs can be seen in the tutorials. The `path_to_save` should be the same path you used in the previous step.
 ```python
-stats = init_spectra.return_all_split_stats()
-plt.scatter(stats['SPECTRA_parameter'], stats['cross_split_overlap'])
+spectra.return_all_split_stats(show_progress = True, path_to_save = save_path)
 ```
 ## Spectra tutorials
@@ -163,7 +201,7 @@ If there are any other tutorials of interest feel free to raise an issue!
 ## Background
-SPECTRA is from a preprint, for more information on the preprint, the method behind SPECTRA, and the initials studies conducted with SPECTRA, check out the paper folder.
+SPECTRA is [published](https://rdcu.be/d2D0z) in Nature Machine Intelligence. For more code about the method behind SPECTRA and the initial studies conducted with SPECTRA, check out the paper folder.
 ## Discussion and Development
@@ -185,7 +223,7 @@ All development discussions take place on GitHub in this repo in the issue track
 2. *I have a foundation model that is pre-trained on a large amount of data. It is not feasible to do pairwise calculations of SPECTRA properties. How can I use SPECTRA?*
-    It is still possible to run SPECTRA on the foundation model (FM) and the evaluation dataset. You would use SPECTRA on the evaluation dataset then train and evaluate the foundation model on each SPECTRA split (either through linear probing, fine-tuning, or any other strategy) to calculate the AUSPC. Then you would determine the cross-split overlap between the pre-training dataset and the evaluation dataset. You would repeat this for multiple evaluation datasets, until you could plot FM AUSPC versus cross-split overlap to the evaluation dataset. For more details on what this would look like check out the [publication](https://www.biorxiv.org/content/10.1101/2024.02.25.581982v1), specifically section 5 of the results section. If there is large interest in this FAQ I can release a tutorial on this, just raise an issue!
+    It is still possible to run SPECTRA on the foundation model (FM) and the evaluation dataset. You would use SPECTRA on the evaluation dataset then train and evaluate the foundation model on each SPECTRA split (either through linear probing, fine-tuning, or any other strategy) to calculate the AUSPC. Then you would determine the cross-split overlap between the pre-training dataset and the evaluation dataset. You would repeat this for multiple evaluation datasets, until you could plot FM AUSPC versus cross-split overlap to the evaluation dataset. For more details on what this would look like check out the [publication](https://rdcu.be/d2D0z), specifically section 5 of the results section. If there is large interest in this FAQ I can release a tutorial on this, just raise an issue!
 3. *I have a foundation model that is pre-trained on a large amount of data and **I do not have access to the pre-training data**. How can I use SPECTRA?*
@@ -193,7 +231,7 @@ All development discussions take place on GitHub in this repo in the issue track
 4. *SPECTRA takes a long time to run is it worth it?*
-    The pairwise spectral property comparison is computationally expensive, but only needs to be done once. Generated SPECTRA splits are important resources that should be released to the public so others can utlilize them without spending resources. For more details on the runtime of the method check out the [publication](https://www.biorxiv.org/content/10.1101/2024.02.25.581982v1), specifically section 6 of the results section. The computation can be sped up with cpu cores, which is a feature that will be released.
+    The pairwise spectral property comparison is computationally expensive, but only needs to be done once. Generated SPECTRA splits are important resources that should be released to the public so others can utilize them without spending resources. For more details on the runtime of the method check out the [publication](https://rdcu.be/d2D0z), specifically section 6 of the results section. The computation can be sped up with CPU cores, which is a feature that will be released.
 If there are any other questions please raise them in the issues and I can address them. I'll keep adding to the FAQ as common questions begin to surface.
@@ -206,15 +244,20 @@ SPECTRA is under the MIT license found in the LICENSE file in this GitHub reposi
 Please cite this paper when referring to SPECTRA.
 ```
-@article {spectra,
-    author = {Yasha Ektefaie and Andrew Shen and Daria Bykova and Maximillian Marin and Marinka Zitnik and Maha R Farhat},
-    title = {Evaluating generalizability of artificial intelligence models for molecular datasets},
-    elocation-id = {2024.02.25.581982},
-    year = {2024},
-    doi = {10.1101/2024.02.25.581982},
-    URL = {https://www.biorxiv.org/content/early/2024/02/28/2024.02.25.581982},
-    eprint = {https://www.biorxiv.org/content/early/2024/02/28/2024.02.25.581982.full.pdf},
-    journal = {bioRxiv}
+@ARTICLE{Ektefaie2024,
+  title = "Evaluating generalizability of artificial intelligence models
+           for molecular datasets",
+  author = "Ektefaie, Yasha and Shen, Andrew and Bykova, Daria and Marin,
+            Maximillian G and Zitnik, Marinka and Farhat, Maha",
+  journal = "Nat. Mach. Intell.",
+  publisher = "Springer Science and Business Media LLC",
+  volume = 6,
+  number = 12,
+  pages = "1512--1524",
+  month = dec,
+  year = 2024,
+  copyright = "https://www.springernature.com/gp/researchers/text-and-data-mining",
+  language = "en"
 }
 ```
diff --git a/spectrae/__init__.py b/spectrae/__init__.py
index a06475f..c7cab25 100644
--- a/spectrae/__init__.py
+++ b/spectrae/__init__.py
@@ -1,2 +1,3 @@
-from .spectra import Spectra
-from .dataset import SpectraDataset
\ No newline at end of file
+from .spectra import Spectra, Spectra_Property_Graph_Constructor
+from .dataset import SpectraDataset
+from .utils import Spectral_Property_Graph, FlattenedAdjacency, plot_split_stats
diff --git a/spectrae/dataset.py b/spectrae/dataset.py
index 2d683f7..03edba7 100644
--- a/spectrae/dataset.py
+++ b/spectrae/dataset.py
@@ -1,38 +1,41 @@
 from abc import ABC, abstractmethod
+from typing import List, Dict
 class SpectraDataset(ABC):
     def __init__(self, input_file, name):
         self.input_file = input_file
         self.name = name
-        self.samples = self.parse(input_file)
-
-    @abstractmethod
-    def sample_to_index(self, idx):
-        """
-        Given a sample, return the data idx
-        """
-        pass
-
+        self.sample_to_index = self.parse(input_file)
+        self.samples = list(self.sample_to_index.keys())
+        self.samples.sort()
+        self.index_map = {value: idx for idx, value in enumerate(self.samples)}
     @abstractmethod
-    def parse(self, input_file):
+    def parse(self, input_file: str) -> Dict:
         """
-        Given a dataset file, parse the dataset file.
-        Make sure there are only unique entries!
+ Given a dataset file, parse the dataset file to return a dictionary mapping a sample ID to the data """ - pass + raise NotImplementedError("Must implement parse method to use SpectraDataset, see documentation for more information") - @abstractmethod def __len__(self): """ Return the length of the dataset """ - pass + return len(self.samples) - @abstractmethod def __getitem__(self, idx): """ Given a dataset idx, return the element at that index """ - pass \ No newline at end of file + if isinstance(idx, int): + return self.sample_to_index[self.samples[idx]] + return self.sample_to_index[idx] + + def index(self, value): + """ + Given a value, return the index of that value + """ + if value not in self.index_map: + raise ValueError(f"{value} not in the dataset") + return self.index_map[value] \ No newline at end of file diff --git a/spectrae/independent_set_algo.py b/spectrae/independent_set_algo.py index 7d0ba35..f8c0805 100644 --- a/spectrae/independent_set_algo.py +++ b/spectrae/independent_set_algo.py @@ -1,83 +1,113 @@ -import random -import networkx as nx -from .utils import is_clique, connected_components, is_integer -from scipy import stats +import random import numpy as np +from tqdm import tqdm +import torch +from .utils import FlattenedAdjacency, Spectral_Property_Graph, cross_split_overlap -def run_independent_set(spectral_parameter, input_G, seed = None, - debug=False, distribution = None, binary = True): - total_deleted = 0 - independent_set = [] - - if seed is not None: - random.seed(seed) +def run_independent_set(spectral_parameter: int, + input_G: Spectral_Property_Graph, + seed: int = 42, + binary: bool = True, + minimum: int = None, + degree_choosing: bool = False, + num_splits: int = None, + debug_mode: bool = False): - G = input_G.copy() - if binary: - #First check if any connected component of the graph is a clique, if so, add it as one unit to the independent set - components = list(connected_components(G)) - deleted = 0 - for i, component in 
enumerate(components): - subgraph = G.subgraph(component) - if is_clique(subgraph): - print(f"Component {i} is too densly connected, adding samples as a single unit to independent set and deleting them from the graph") - independent_set.append(list(subgraph.nodes())) - G.remove_nodes_from(subgraph.nodes()) + + total_num_deleted = 0 + independent_set = [] + random.seed(seed) + + n = input_G.num_nodes() + indices_to_scan = list(range(n)) + if spectral_parameter == 0: + return indices_to_scan + pbar = tqdm(total = len(indices_to_scan)) + + #Trying a non-percentile approach + #Note this assumes there are 20 + if not binary: + if num_splits is None: + raise Exception("Num splits must be specified for non-binary graphs, see documentation for more information") + #Higher spectral parameter means more nodes deleted, so lower threshold + threshold = (1-spectral_parameter)*(input_G.max() - input_G.min()) + else: + threshold = 0 + if debug_mode: + print(f"Threshold is {threshold}") + indices_deleted = [] + full_indices_deleted = [] + + expected_number_delete = int(n * spectral_parameter) + if debug_mode: + print(expected_number_delete) + min_degree_node = input_G.get_minimum_degree_node() + num_deleted_in_iteration = 0 + + while len(indices_to_scan) > 0: + if debug_mode: + print(f'Num possibly deleted {len(indices_deleted)}, num actually deleted {num_deleted_in_iteration}, number of nodes left to consider {len(indices_to_scan)}') + num_deleted_in_iteration = 0 + indices_deleted = [] + if degree_choosing: + if len(full_indices_deleted) > 0: + chosen_node, _ = min_degree_node.send(full_indices_deleted) else: - for node in list(subgraph.nodes()): - if subgraph.degree(node) == len(subgraph.nodes()) - 1: - deleted += 1 - G.remove_node(node) + chosen_node, _ = next(min_degree_node) + else: + chosen_node = random.sample(indices_to_scan, 1)[0] + + indices_to_scan.remove(chosen_node) + full_indices_deleted.append(chosen_node) + total_num_deleted += 1 + num_deleted_in_iteration += 1 + 
+ to_iterate = indices_to_scan[:] - print(f"Deleted {deleted} nodes from the graph since they were connected to all other nodes") + indices_to_gather = [] - iterations = 0 - - while not nx.is_empty(G): - chosen_node = random.sample(list(G.nodes()), 1)[0] + for index in to_iterate: + indices_to_gather.append((chosen_node, index)) + + values = input_G.get_weights(indices_to_gather) + indices_deleted.extend(list(torch.tensor(to_iterate).cuda()[values > threshold].cpu().numpy())) + + indices_deleted = list(set(indices_deleted)) + indices_to_scan = set(indices_to_scan) - independent_set.append(chosen_node) - neighbors = G.neighbors(chosen_node) - neighbors_to_delete = [] + # if len(indices_deleted) > expected_number_delete: + # indices_deleted = [chosen_node] + # total_num_deleted += 1 + # else: - for neighbor in neighbors: - if not binary: - if spectral_parameter == 1.0: - neighbors_to_delete.append(neighbor) - else: - edge_weight = G[chosen_node][neighbor]['weight'] - if distribution is None: - raise Exception("Distribution must be provided if binary is set to False, must precompute similarities") - if random.random() < spectral_parameter and (1-spectral_parameter)*100 < stats.percentileofscore(distribution, edge_weight): - neighbors_to_delete.append(neighbor) + independent_set.append(chosen_node) + for i in indices_deleted: + if binary: + if random.random() < spectral_parameter: + indices_to_scan.remove(i) + total_num_deleted += 1 + num_deleted_in_iteration += 1 + full_indices_deleted.append(i) else: - if spectral_parameter == 1.0: - neighbors_to_delete.append(neighbor) - elif spectral_parameter != 0.0: - if random.random() < spectral_parameter: - neighbors_to_delete.append(neighbor) - - if debug: - print(f"Iteration {iterations} Stats") - print(f"Deleted {len(neighbors_to_delete)} nodes from {G.degree(chosen_node)} neighbors of node {chosen_node}") - total_deleted += len(neighbors_to_delete) + indices_to_scan.remove(i) + total_num_deleted += 1 + 
num_deleted_in_iteration += 1 + full_indices_deleted.append(i) - for neighbor in neighbors_to_delete: - G.remove_node(neighbor) - - if chosen_node not in neighbors_to_delete: - G.remove_node(chosen_node) - - iterations += 1 - - for node in list(G.nodes()): - #Append the nodes left to G - independent_set.append(node) + if minimum is not None: + if n - total_num_deleted <= minimum - len(independent_set): + independent_set.extend(indices_to_scan) + return independent_set + + indices_deleted.append(chosen_node) - if debug: - print(f"{len(input_G.nodes())} nodes in the original graph") - print(f"Total deleted {total_deleted}") - print(f"{len(independent_set)} nodes in the independent set") + indices_to_scan = list(indices_to_scan) + pbar.update(num_deleted_in_iteration) + + if len(indices_to_scan) != n - len(full_indices_deleted): + raise Exception("Length of indices to scan is not equal to n (num nodes) - len(full_indices_deleted), logic is not met") + + pbar.close() - return independent_set \ No newline at end of file + return independent_set diff --git a/spectrae/spectra.py b/spectrae/spectra.py index 45bbae6..3a45906 100644 --- a/spectrae/spectra.py +++ b/spectrae/spectra.py @@ -3,22 +3,26 @@ from sklearn.model_selection import train_test_split import os import pickle -from .utils import is_clique, connected_components, is_integer +from .utils import Spectral_Property_Graph, plot_split_stats +from .dataset import SpectraDataset import numpy as np from tqdm import tqdm import pandas as pd from abc import ABC, abstractmethod +import pickle +import torch +from typing import List, Tuple, Optional, Dict class Spectra(ABC): - def __init__(self, dataset, - binary = True): + def __init__(self, dataset: SpectraDataset, + spg: Spectral_Property_Graph): #SPECTRA properties should be a function that given two samples in your dataset, returns whether they are similar or not #Cross split overlap should be a function that given two lists of samples, returns the overlap between the 
two lists self.dataset = dataset - self.SPG = None - self.spectra_properties_loaded = None - self.binary = binary + self.SPG = spg + if self.SPG is not None: + self.binary = self.SPG.binary @abstractmethod def spectra_properties(self, sample_one, sample_two): @@ -29,92 +33,73 @@ def spectra_properties(self, sample_one, sample_two): """ pass - @abstractmethod - def cross_split_overlap(self, train, test): - """ - Define the cross split overlap between two lists of samples. - Ideally should be a number between 0 and 1 that defines the overlap between the two lists of samples - - """ - pass - - def construct_spectra_graph(self, force_reconstruct = False): - if self.SPG is not None: - return self.SPG - elif os.path.exists(f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") and not force_reconstruct: - print("Loading spectral property graph") - self.SPG = nx.read_gexf(f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") - self.return_spectra_graph_stats() - return self.SPG - else: - self.SPG = nx.Graph() - if self.spectra_properties_loaded is not None: - for row in tqdm(self.spectra_properties_loaded.itertuples(), - total = len(self.spectra_properties_loaded)): - if row[3]: - self.SPG.add_edge(row[1], row[2], weight = row[3]) - else: - for i in tqdm(range(len(self.dataset))): - for j in range(i+1, len(self.dataset)): - if self.spectra_properties_loaded is not None: - weight = self.spectra_properties_loaded[(self.spectra_properties_loaded[0] == i) & (self.spectra_properties_loaded[1] == j)][2].values[0] - else: - weight = self.spectra_properties(self.dataset[i], self.dataset[j]) - if weight: - self.SPG.add_edge(i, j, weight = weight) - - self.return_spectra_graph_stats() + def cross_split_overlap(self, + split: List[int], + split_two: Optional[List[int]] = None, + chunksize: int = 10000000, + show_progress: bool = False) -> Tuple[float, float, float]: + + def 
calculate_overlap(index_to_gather): + if self.SPG.binary: + num_similar = 0 - if self.binary: - #Check to make sure SPG is not fully connected - if is_clique(self.SPG): - print("SPG is fully connected") - raise Exception("The SPG is fully connected, cannot run SPECTRA, all samples are similar to each other") + if show_progress: + index_to_gather = tqdm(index_to_gather, total = len(split)) else: - print("SPG is not fully connected") - components = list(connected_components(self.SPG)) - all_fully_connected = True - for i, component in enumerate(components): - subgraph = self.SPG.subgraph(component) - if is_clique(subgraph): - print(f"Component {i} is fully connected, all samples are similar to each other") - #raise Exception("The SPG is fully connected, cannot run SPECTRA") - else: - all_fully_connected = False - print(f"Component {i} is not fully connected") - if all_fully_connected: - raise Exception("All SPG sub components are fully connected, cannot run SPECTRA, all samples are similar to each other") - - if not os.path.exists(f"{self.dataset.name}_spectral_property_graphs"): - os.makedirs(f"{self.dataset.name}_spectral_property_graphs") - - nx.write_gexf( self.SPG, f"{self.dataset.name}_spectral_property_graphs/{self.dataset.name}_SPECTRA_property_graph.gexf") + index_to_gather = index_to_gather + + for compare_list in index_to_gather: + if self.SPG.get_weights(compare_list).sum() > 0: + num_similar += 1 - return self.SPG + return num_similar/(len(split)), num_similar, len(split) + else: + mean_val = 0.0 + std_val = 0.0 + max_val = float('-inf') + min_val = float('inf') + count = 0 + for compare_list in index_to_gather: + weights = self.SPG.get_weights(compare_list) + mean_val += weights.sum() + std_val += (weights ** 2).sum() + if weights.max() > max_val: + max_val = weights.max() + if weights.min() < min_val: + min_val = weights.min() + count += len(weights) + + if count > 100000000: + break + mean_val /= count + std_val = (std_val / count - mean_val ** 2) ** 
0.5 + return mean_val, std_val, max_val, min_val + + def generate_indices(split, split_two): + if split_two is not None: + for i in range(len(split)): + to_compare = [] + for j in range(len(split_two)): + to_compare.append((split[i], split_two[j])) + yield to_compare + else: + for i in range(len(split)): + to_compare = [] + for j in range(i+1, len(split)): + to_compare.append((split[i], split[j])) + yield to_compare + + index_to_gather = generate_indices(split, split_two) + + return calculate_overlap(index_to_gather) def return_spectra_graph_stats(self): - if self.SPG is None: - self.construct_spectra_graph() + num_nodes, num_edges, density = self.SPG.stats() print("Stats for SPECTRA property graph (SPG)") - print(f"Number of nodes: {self.SPG.number_of_nodes()}") - print(f"Number of edges: {self.SPG.number_of_edges()}") - num_connected_components = nx.number_connected_components(self.SPG) - print(f"Number of connected components: {num_connected_components}\n\n") - if num_connected_components > 1: - print("Connected component stats") - components = list(connected_components(self.SPG)) - densities = [] - for i, component in enumerate(components): - subgraph = self.SPG.subgraph(component) - print(f"Component {i} has {subgraph.number_of_nodes()} nodes and {subgraph.number_of_edges()} edges") - print(f"Density of component {i}: {nx.density(subgraph)}") - densities.append(nx.density(subgraph)) - if is_clique(subgraph): - print(f"Component {i} is fully connected, all samples are similar to each other") - else: - print(f"Component {i} is not fully connected") - - print(f"Average density {np.mean(densities)}") + print(f"Number of nodes: {num_nodes}") + print(f"Number of edges: {num_edges}") + print(f"Density of SPG: {density}") + return num_nodes, num_edges, density def spectra_train_test_split(self, nodes, test_size, random_state): train = [] @@ -133,76 +118,190 @@ def spectra_train_test_split(self, nodes, test_size, random_state): return train, test def get_samples(self, 
nodes): - return [self.dataset[int(i)] for i in nodes] + return [self.dataset[i] for i in nodes] + + def get_sample_indices(self, samples): + return [self.dataset.index(i) for i in samples] def generate_spectra_split(self, - spectral_parameter, - random_seed, - test_size = 0.2): + spectral_parameter: float, + random_seed: int = 42, + test_size: float = 0.2, + degree_choosing: bool = False, + minimum: int = None, + path_to_save: str = None, + num_splits: int = None, + debug_mode: bool = False): - spectral_property_graph = self.SPG print(f"Generating SPECTRA split for spectral parameter {spectral_parameter} and dataset {self.dataset.name}") - result = run_independent_set(spectral_parameter, spectral_property_graph, - seed = random_seed, - distribution = self.spectra_properties_loaded[2], - binary = self.binary) - if len(result) <= 5: - return None, None, None + result = run_independent_set(spectral_parameter, self.SPG, + seed = random_seed, + binary = self.binary, + minimum = minimum, + degree_choosing = degree_choosing, + num_splits = num_splits, + debug_mode = debug_mode) + + if len(result) <= 10: + raise Exception("Independent set has less than 10 samples, cannot generate split") print(f"Number of samples in independent set: {len(result)}") train, test = self.spectra_train_test_split(result, test_size=test_size, random_state=random_seed) - print(f"Train size: {len(train)}\tTest size: {len(test)}") - cross_split_overlap = self.cross_split_overlap(self.get_samples(train), self.get_samples(test)) - print(f"Cross split overlap: {cross_split_overlap}\n\n\n") - stats = {'SPECTRA_parameter': spectral_parameter, 'train_size': len(train), 'test_size': len(test), 'cross_split_overlap': cross_split_overlap} - return train, test, stats + stats = self.get_stats(train, test, spectral_parameter) + if path_to_save is None: + return train, test, stats + else: + i = 0 + if not os.path.exists(f"{path_to_save}/SP_{spectral_parameter}_{i}"): + 
os.makedirs(f"{path_to_save}/SP_{spectral_parameter}_{i}") + + pickle.dump(train, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/train.pkl", "wb")) + pickle.dump([self.dataset.samples[i] for i in train], open(f"{path_to_save}/SP_{spectral_parameter}_{i}/train_IDs.pkl", "wb")) + + pickle.dump(test, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/test.pkl", "wb")) + pickle.dump([self.dataset.samples[i] for i in test], open(f"{path_to_save}/SP_{spectral_parameter}_{i}/test_IDs.pkl", "wb")) + + pickle.dump(stats, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/stats.pkl", "wb")) + + def get_stats(self, train: List, + test: List, + spectral_parameter: float, + chunksize: int = 10000000, + show_progress: bool = False, + sample_values: bool = False): + + """ + Computes statistics for the given train and test splits. + + Args: + train (List): A list of training sample IDs or sample indices. (see sample_values) + test (List): A list of test sample IDs or sample indices. (see sample_values) + spectral_parameter (float): The spectral parameter used for computation. + chunksize (int, optional): The size of chunks to process at a time. Default is 10,000,000. Decrease if you get a OOM error. + show_progress (bool, optional): Whether to show progress during computation. Default is False. + sample_values (bool, optional): True if you are passing sample IDs, False if you are passing sample indices. Default is False. + + Returns: + Dict[str, Any]: A dictionary containing the computed statistics. The keys and values depend on whether the data is binary or not. + If not binary: + - 'SPECTRA_parameter' (float): The spectral parameter used. + - 'train_size' (int): The size of the training set. + - 'test_size' (int): The size of the testing set. + - 'cross_split_overlap' (float): The cross-split overlap value. + - 'std_css' (float): The standard deviation of the cross-split similarity. + - 'max_css' (float): The maximum cross-split similarity. 
+ - 'min_css' (float): The minimum cross-split similarity. + If binary: + - 'SPECTRA_parameter' (float): The spectral parameter used. + - 'train_size' (int): The size of the training set. + - 'test_size' (int): The size of the testing set. + - 'cross_split_overlap' (float): The cross-split overlap value. + - 'num_similar' (int): The number of similar items. + - 'num_total' (int): The total number of items. + + Raises: + ValueError: If the train or test lists are empty. + + """ + + train_size = len(train) + test_size = len(test) + + if sample_values: + train = self.get_sample_indices(train) + test = self.get_sample_indices(test) + + if not self.binary: + cross_split_overlap, std_css, max_css, min_css = self.cross_split_overlap(train, test, chunksize, show_progress) + stats = {'SPECTRA_parameter': spectral_parameter, + 'train_size': train_size, + 'test_size': test_size, + 'cross_split_overlap': cross_split_overlap, + 'std_css': std_css, + 'max_css': max_css, + 'min_css': min_css} + else: + cross_split_overlap, num_similar, num_total = self.cross_split_overlap(train, test, chunksize, show_progress) + stats = {'SPECTRA_parameter': spectral_parameter, + 'train_size': train_size, + 'test_size': test_size, + 'cross_split_overlap': cross_split_overlap, + 'num_similar': num_similar, + 'num_total': num_total} + return stats def generate_spectra_splits(self, - spectral_parameters, - number_repeats, - random_seed, - test_size = 0.2, - force_reconstruct = False): + spectral_parameters: List[float], + number_repeats: int, + random_seed: List[float], + test_size: float = 0.2, + degree_choosing: bool = False, + minimum: int = None, + force_reconstruct: bool = False, + path_to_save: str = None): #Random seed is a list of random seeds for each number name = self.dataset.name - self.construct_spectra_graph(force_reconstruct = force_reconstruct) if self.binary: - if nx.density(self.SPG) >= 0.4: + if self.SPG.get_density() >= 0.4: raise Exception("Density of SPG is greater than 0.4, 
SPECTRA will not work as your dataset is too similar to itself. Please check your dataset and SPECTRA properties.") - if not os.path.exists(f"{name}_SPECTRA_splits"): - os.makedirs(f"{name}_SPECTRA_splits") - if not os.path.exists(f"{name}_spectral_property_graphs"): - os.makedirs(f"{name}_spectral_property_graphs") + if path_to_save is None: + path_to_save = f"{name}_SPECTRA_splits" + + if not os.path.exists(path_to_save): + os.makedirs(path_to_save) splits = [] for spectral_parameter in spectral_parameters: for i in range(number_repeats): - if os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}") and not force_reconstruct: + if os.path.exists(f"{path_to_save}/SP_{spectral_parameter}_{i}") and not force_reconstruct: print(f"Folder SP_{spectral_parameter}_{i} already exists. Skipping") - elif force_reconstruct or not os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}"): - train, test, stats = self.generate_spectra_split(float(spectral_parameter), random_seed[i], test_size) + elif force_reconstruct or not os.path.exists(f"{path_to_save}/SP_{spectral_parameter}_{i}"): + train, test, stats = self.generate_spectra_split(float(spectral_parameter), random_seed[i], test_size, degree_choosing, minimum) if train is not None: - if not os.path.exists(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}"): - os.makedirs(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}") + if not os.path.exists(f"{path_to_save}/SP_{spectral_parameter}_{i}"): + os.makedirs(f"{path_to_save}/SP_{spectral_parameter}_{i}") - pickle.dump(train, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/train.pkl", "wb")) - pickle.dump(test, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/test.pkl", "wb")) - pickle.dump(stats, open(f"{name}_SPECTRA_splits/SP_{spectral_parameter}_{i}/stats.pkl", "wb")) + pickle.dump(train, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/train.pkl", "wb")) + pickle.dump(test, 
open(f"{path_to_save}/SP_{spectral_parameter}_{i}/test.pkl", "wb")) + pickle.dump(stats, open(f"{path_to_save}/SP_{spectral_parameter}_{i}/stats.pkl", "wb")) else: print(f"Split for SP_{spectral_parameter}_{i} could not be generated since independent set only has one sample") return splits - def return_split_stats(self, spectral_parameter, number): - split_folder = f"./{self.dataset.name}_SPECTRA_splits/SP_{spectral_parameter}_{number}" + def return_split_stats(self, spectral_parameter: float, + number: int, + path_to_save: str = None, + chunksize: int = 10000000, + show_progress: bool = False) -> Dict: + + if path_to_save is None: + path_to_save = f"{self.dataset.name}_SPECTRA_splits" + split_folder = f"./{path_to_save}/SP_{spectral_parameter}_{number}" + else: + split_folder = f"{path_to_save}/SP_{spectral_parameter}_{number}" + if not os.path.exists(split_folder): raise Exception(f"Split folder {split_folder} does not exist") else: + if not os.path.exists(f"{split_folder}/stats.pkl"): + train = pickle.load(open(f"{split_folder}/train.pkl", "rb")) + test = pickle.load(open(f"{split_folder}/test.pkl", "rb")) + stats = self.get_stats(train, test, spectral_parameter, chunksize, show_progress) + pickle.dump(stats, open(f"{split_folder}/stats.pkl", "wb")) + return stats + return pickle.load(open(f"{split_folder}/stats.pkl", "rb")) - def return_split_samples(self, spectral_parameter, number): - split_folder = f"./{self.dataset.name}_SPECTRA_splits/SP_{spectral_parameter}_{number}" + def return_split_samples(self, spectral_parameter: float, + number: int, + path_to_save: str = None): + + if path_to_save is None: + path_to_save = f"{self.dataset.name}_SPECTRA_splits" + + split_folder = f"./{path_to_save}/SP_{spectral_parameter}_{number}" if not os.path.exists(split_folder): raise Exception(f"Split folder {split_folder} does not exist") else: @@ -210,64 +309,135 @@ def return_split_samples(self, spectral_parameter, number): test = 
pickle.load(open(f"{split_folder}/test.pkl", "rb")) return [self.dataset[int(i)] for i in train], [self.dataset[int(i)] for i in test] - def return_all_split_stats(self): + def return_all_split_stats(self, + path_to_save: str = None, + chunksize: int = 10000000, + show_progress: bool = False) -> Dict: + + if path_to_save is None: + path_to_save = f"{self.dataset.name}_SPECTRA_splits" + SP = [] numbers = [] train_size = [] test_size = [] cross_split_overlap = [] + if not self.binary: + std_css = [] + max_css = [] + min_css = [] - for folder in os.listdir(f"{self.dataset.name}_SPECTRA_splits"): + if not show_progress: + to_iterate = os.listdir(path_to_save) + else: + to_iterate = tqdm(os.listdir(path_to_save)) + + for folder in to_iterate: spectral_parameter = folder.split('_')[1] number = folder.split('_')[2] - res = self.return_split_stats(spectral_parameter, number) + res = self.return_split_stats(spectral_parameter, number, chunksize=chunksize, path_to_save=path_to_save, show_progress=show_progress) SP.append(float(spectral_parameter)) numbers.append(int(number)) train_size.append(int(res['train_size'])) test_size.append(int(res['test_size'])) cross_split_overlap.append(float(res['cross_split_overlap'])) + if not self.binary: + std_css.append(float(res['std_css'])) + max_css.append(float(res['max_css'])) + min_css.append(float(res['min_css'])) stats = {'SPECTRA_parameter': SP, 'number': number, 'train_size': train_size, 'test_size': test_size, 'cross_split_overlap': cross_split_overlap} + if not self.binary: + stats['std_css'] = std_css + stats['max_css'] = max_css + stats['min_css'] = min_css + pickle.dump(stats, open(f"{path_to_save}/all_stats.pkl", "wb")) + plot_split_stats(stats = stats) return stats - def pre_calculate_spectra_properties(self, filename, force_recalculate = False): - if os.path.exists(f"{filename}_precalculated_spectra_properties") and not force_recalculate: - print(f"File {filename}_precalculated_spectra_properties already exists, set 
force_recalculate to True to recalculate") - else: - similarity_file = open(f"{filename}_precalculated_spectra_properties", 'w') + # def find_closest(self, overlap, x, y): + # min_difference = 1000000 + # best_overlap = None + # best_param = None + + + # for i,j in zip(x, y): + # if abs(overlap - j) < min_difference: + # best_param = i + # best_overlap = j + # min_difference = abs(overlap - j) + + # print(f"{best_overlap} and {best_param} for {overlap}") - for i in tqdm(range(len(self.dataset))): - for j in range(i+1, len(self.dataset)): - similarity_file.write(f"{i}\t{j}\t{self.spectra_properties(self.dataset[i], self.dataset[j])}\n") - - similarity_file.close() - self.load_spectra_precalculated_spectra_properties(filename) - - def load_spectra_precalculated_spectra_properties(self, filename): - if not os.path.exists(f"{filename}_precalculated_spectra_properties"): - raise Exception(f"File {filename}_precalculated_spectra_properties does not exist") + +class Spectra_Property_Graph_Constructor(): + def __init__(self, spectra: Spectra, + dataset: SpectraDataset, + num_chunks: int = 0, + binary: bool = False): + self.spectra = spectra + self.dataset = dataset + self.num_chunks = num_chunks + if self.num_chunks != 0: + self.data_chunk = np.array_split(list(range(len(self.dataset))), self.num_chunks) else: - self.spectra_properties_loaded = pd.read_csv(f"{filename}_precalculated_spectra_properties", sep = '\t', header = None) + self.data_chunk = [list(range(len(self.dataset)))] + self.binary = binary + + def create_adjacency_matrix(self, chunk_num: int): + to_store = [] - self.non_lookup_spectra_property = self.spectra_properties + for i in tqdm(self.data_chunk[chunk_num]): + for j in range(i, len(self.dataset)): + if i != j: + if self.binary: + if self.spectra.spectra_properties(self.dataset[i], self.dataset[j]): + to_store.append(1) + else: + to_store.append(0) + else: + to_store.append(self.spectra.spectra_properties(self.dataset[i], self.dataset[j])) + + if not 
os.path.exists('adjacency_matrices'): + os.makedirs('adjacency_matrices') + + with open(f'adjacency_matrices/aj_{chunk_num}.npy', 'wb') as f: + pickle.dump(to_store, f) + + def combine_adjacency_matrices(self): + num_adjacency = len(os.listdir('adjacency_matrices')) + if self.num_chunks == 0: + if num_adjacency != 1: + raise Exception("Need to generate adjacency matrices first! See documentation") + else: + if num_adjacency != self.num_chunks: + raise Exception("Need to generate adjacency matrices first! See documentation") + + n = len(self.dataset) + new = np.zeros(int((n*(n-1))/2)) + previous_start = 0 - def lookup_spectra_property(x, y): - if not is_integer(x) or not is_integer(y): - return self.non_lookup_spectra_property(x, y) + for i in tqdm(range(self.num_chunks)): + to_assign = np.load(f'adjacency_matrices/aj_{i}.npy', allow_pickle=True) + new[previous_start:previous_start+len(to_assign)] = to_assign + previous_start += len(to_assign) + if self.binary: + new = new.astype(np.int8) else: - res1 = self.spectra_properties_loaded[(self.spectra_properties_loaded[0] == x) & (self.spectra_properties_loaded[1] == y)] - res2 = self.spectra_properties_loaded[(self.spectra_properties_loaded[0] == y) & (self.spectra_properties_loaded[1] == x)] - if len(res1) > 0: - return res1[2].values[0] - elif len(res2) > 0: - return res2[2].values[0] - else: - raise Exception(f"SPECTRA property between {x} and {y} not found in precalculated file") - - self.spectra_properties = lookup_spectra_property + new = new.astype(np.float16) + + if self.num_chunks == 0: + new = np.load(f'adjacency_matrices/aj_0.npy', allow_pickle=True) + if self.binary: + torch.save(torch.tensor(new).to(torch.int8), 'flattened_adjacency_matrix.pt') + else: + torch.save(torch.tensor(new).half(), 'flattened_adjacency_matrix.pt') + + + diff --git a/spectrae/utils.py b/spectrae/utils.py index 890723b..1c3dd30 100644 --- a/spectrae/utils.py +++ b/spectrae/utils.py @@ -1,12 +1,236 @@ -from 
networkx.algorithms.components import connected_components +import torch +import numpy as np +from tqdm import tqdm +from typing import List, Tuple, Dict, Union, Optional +import os +import matplotlib.pyplot as plt +import pickle -def is_integer(n): - if isinstance(n, int): - return True - elif isinstance(n, float): - return n.is_integer() - else: - return False +class FlattenedAdjacency: + def __init__(self, + flattened_adjacency_path: str): + + self.flattened_adjacency = torch.load(flattened_adjacency_path) + if torch.cuda.is_available(): + self.flattened_adjacency = self.flattened_adjacency.cuda() + self.n = self.get_number_len(len(self.flattened_adjacency)) + if self.flattened_adjacency.dtype is torch.int8: + self.binary = True + elif self.flattened_adjacency.dtype is torch.float16: + self.binary = False + else: + raise ValueError("Invalid datatype. Use torch.int8 or torch.float16.") + + def get_number_len(self, number_items): + return int(abs(-1-np.sqrt(1+8*number_items))/2) + + def return_index_flat(self, i, j): + if i == j: + return 1 + + if i > j: + i, j = j, i + + sum = i * self.n - (i * (i + 1)) // 2 + return sum + j - i - 1 + + def __len__(self): + return self.n + + def __getitem__(self, indices): + if isinstance(indices, tuple) and len(indices) == 2: + i, j = indices + return self.flattened_adjacency[self.return_index_flat(i, j)].to(torch.int64) + elif isinstance(indices, list): + to_index = [] + for i, j in indices: + to_index.append(self.return_index_flat(i, j)) + if self.binary: + return self.flattened_adjacency[to_index].to(torch.int64) + return self.flattened_adjacency[to_index].to(torch.float32) + + else: + raise IndexError("Invalid index. 
Use FlattenedAdjacency[i, j] or FlattenedAdjacency[[i,j],[k,l]] for indexing.") + +class Spectral_Property_Graph: + def __init__(self, + flattened_adjacency: FlattenedAdjacency): + + self.flattened_adjacency = flattened_adjacency + self.degree_distribution = None + self.binary = self.flattened_adjacency.binary + + def num_nodes(self): + return len(self.flattened_adjacency) + + def num_edges(self): + num_nodes = self.num_nodes() + return int(num_nodes*(num_nodes-1)/2) + + def chunked_sum(self, chunk_size: int): + """ + Sums a tensor in chunks to avoid OOM errors. + + Args: + tensor (torch.Tensor): The input tensor to be summed. + chunk_size (int): The size of each chunk. + + Returns: + torch.Tensor: The sum of the tensor. + """ + tensor = self.flattened_adjacency.flattened_adjacency + total_sum = torch.zeros_like(tensor[0], dtype=torch.int64) + num_chunks = (tensor.size(0) + chunk_size - 1) // chunk_size # Calculate the number of chunks + + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min((i + 1) * chunk_size, tensor.size(0)) + chunk = tensor[start_idx:end_idx] + total_sum += chunk.sum(dim=0) + + return total_sum + + def get_density(self, chunk_size: int = 1000000000, return_sum: bool = False): + result = self.chunked_sum(chunk_size) + if return_sum: + return result, result/self.num_edges() + return result/self.num_edges() + + def get_degree(self, node: int): + n = self.num_nodes() + + indexes = [] + for i in range(n): + indexes.append((i, node)) + #values.append(self.flattened_adjacency[i, node]) + + values = self.flattened_adjacency[indexes].type(torch.int64) + return torch.sum(values), torch.sum(values)/n + #return sum(values), sum(values)/n + + def get_degree_distribution(self, + track: bool = False, + save: bool = False, + name: str = "degree_distribution", + ) -> Dict[int, int]: + + if f"{name}.pt" in os.listdir(): + if not torch.cuda.is_available(): + self.degree_distribution = torch.load(f"{name}.pt", map_location = "cpu") + 
else: + self.degree_distribution = torch.load(f"{name}.pt", map_location = "cuda") + self.sorted_minimum_keys = sorted(self.degree_distribution, key=lambda k: self.degree_distribution[k]) + return self.degree_distribution + + n = self.num_nodes() + + degrees = {} + if track: + to_iterate = tqdm(range(n)) + else: + to_iterate = range(n) + + for i in to_iterate: + degrees[i] = self.get_degree(i)[0] + + if save: + torch.save(degrees, f"{name}.pt") + + self.degree_distribution = degrees + self.sorted_minimum_keys = sorted(degrees, key=lambda k: self.degree_distribution[k]) + return self.degree_distribution + + def get_minimum_degree_node(self, deleted_nodes: List[int] = None) -> Tuple[int, int]: + if self.degree_distribution is None: + self.degree_distribution = self.get_degree_distribution(track = True, save = True) + + current_index = 0 + + while current_index < len(self.sorted_minimum_keys): + key = self.sorted_minimum_keys[current_index] + current_index += 1 + if deleted_nodes is None or key not in deleted_nodes: + deleted_nodes = yield key, self.degree_distribution[key] + deleted_nodes = set(deleted_nodes) + + def get_weight(self, i: int, j:int): + return self.flattened_adjacency[i, j] + + def get_weights(self, indices: List[Tuple[int, int]]): + return self.flattened_adjacency[indices] -def is_clique(G): - return G.size() == (G.order()*(G.order()-1))/2 \ No newline at end of file + def get_stats(self): + return self.num_nodes(), self.num_edges(), self.get_density(return_sum = False) + + def max(self): + return torch.max(self.flattened_adjacency.flattened_adjacency) + + def min(self): + return torch.min(self.flattened_adjacency.flattened_adjacency) + + + +def cross_split_overlap(split, g): + binary = g.binary + if binary: + num_similar = 0 + for i in range(len(split)): + for j in range(i+1, len(split)): + if g.get_weight(split[i], split[j]) > 0: + num_similar += 1 + break + return num_similar/len(split), num_similar, len(split) + else: + index_to_gather = [] + for i in 
range(len(split)): + for j in range(i+1, len(split)): + index_to_gather.append((split[i], split[j])) + if len(index_to_gather) > 100000000: + values = g.get_weights(index_to_gather) + return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item() + + index_to_gather = torch.tensor(index_to_gather).cuda() + values = g.get_weights(index_to_gather) + return torch.mean(values).item(), torch.std(values).item(), torch.max(values).item(), torch.min(values).item() + +def plot_split_stats(stats_file: str = None, name: str = None, stats: Optional[Dict] = None): + if stats is None: + with open(stats_file, 'rb') as f: + stats = pickle.load(f) + + spectral_parameter = stats['SPECTRA_parameter'] + train_length = stats['train_size'] + test_length = stats['test_size'] + css = stats['cross_split_overlap'] + + # Convert spectral_parameter to a numeric type if necessary + spectral_parameter = list(map(float, spectral_parameter)) + + # Create the scatter plots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8)) + + # Dataset size vs Spectral parameter + ax1.scatter(spectral_parameter, train_length, color='blue', label='Train') + ax1.scatter(spectral_parameter, test_length, color='green', label='Test') + ax1.set_title('Train and test set size vs Spectral Parameter') + ax1.set_xlabel('Spectral Parameter') + ax1.set_ylabel('Dataset Size') + ax1.legend() + + # Cross split overlap vs Spectral parameter + ax2.scatter(spectral_parameter, css, color='red') + if 'std_css' in stats: + ax2.errorbar(spectral_parameter, css, yerr=stats['std_css'], fmt='o', color='red', ecolor='lightgray', elinewidth=2, capsize=3) + ax2.scatter(spectral_parameter, stats['max_css'], color='purple', label='Max CSS') + ax2.legend() + + ax2.set_title('Cross Split Overlap vs Spectral Parameter') + ax2.set_xlabel('Spectral Parameter') + ax2.set_ylabel('Cross Split Overlap') + + # Adjust layout and save the plot + plt.tight_layout() + if name is not None: + 
plt.savefig(f'{name}.png') + else: + plt.savefig('split_stats.png') + plt.show()
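
Note on the flattened adjacency layout: `FlattenedAdjacency.return_index_flat` and `get_number_len` in this diff store only the strict upper triangle of the symmetric similarity matrix, row by row, so a dataset of `n` samples needs `n * (n - 1) / 2` values. The sketch below restates that index arithmetic as standalone functions so it can be checked without torch. The names `flat_index` and `num_nodes_from_length` are hypothetical and not part of spectrae. The sketch also raises on `i == j`, which is an assumption of this example; the diff returns index `1` in that case.

```python
import math


def flat_index(i: int, j: int, n: int) -> int:
    """Position of the unordered pair (i, j) in the row-major strict
    upper triangle of an n x n symmetric matrix (hypothetical helper)."""
    if i == j:
        # Assumption of this sketch: diagonal entries are not stored.
        raise ValueError("diagonal entries are not stored")
    if i > j:
        i, j = j, i
    # Rows 0..i-1 hold (n-1) + (n-2) + ... + (n-i) entries in total.
    row_offset = i * n - (i * (i + 1)) // 2
    return row_offset + (j - i - 1)


def num_nodes_from_length(m: int) -> int:
    """Invert m = n * (n - 1) / 2 to recover n (hypothetical helper)."""
    n = (1 + math.isqrt(1 + 8 * m)) // 2
    if n * (n - 1) // 2 != m:
        raise ValueError(f"{m} is not a valid flattened length")
    return n


# With n = 5 the ten pairs (0,1), (0,2), ..., (3,4) map to 0..9 in order.
n = 5
pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
print([flat_index(i, j, n) for i, j in pairs])
print(num_nodes_from_length(len(pairs)))
```

Using integer arithmetic (`math.isqrt`) instead of `np.sqrt` avoids floating-point rounding for very large matrices, which is a design choice of this sketch rather than something the diff does.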