|
| 1 | +import sys |
| 2 | +import mudata as mu |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | + |
## VIASH START
# Example parameter values and resource location used by the code below.
# NOTE(review): presumably the region between the VIASH START/END markers is
# substituted by viash when the component is built, so these values would only
# apply when the script is run directly -- confirm against the viash docs.
par = {
    "input": "test_with_probabilities.h5mu",
    "modality": "rna",
    # One .obs column with predicted labels per annotation method.
    "input_obs_predictions": ["scanvi_pred", "celltypist_pred", "singler_pred"],
    # Optional .obs columns with probabilities; must have the same length as
    # the prediction columns (checked in main()).
    "input_obs_probabilities": ["scanvi_prob", "celltypist_prob", "singler_prob"],
    # Optional per-method weights; main() treats all methods equally when unset.
    "weights": None,
    # Label written by main() for cells where several labels share the top weight.
    "tie_label": None,
    "output": "consensus_test_output.h5mu",
    "output_obs_predictions": "consensus_pred",
    "output_obs_score": "consensus_score",
    "output_compression": "gzip",
}
meta = {"resources_dir": "src/utils"}
## VIASH END

# Make the helper modules in the resources directory importable; this must
# happen before the two imports below.
sys.path.append(meta["resources_dir"])
from setup_logger import setup_logger
from compress_h5mu import write_h5ad_to_h5mu_with_compression

logger = setup_logger()
| 27 | + |
| 28 | + |
def _resolve_method_weights(weights, n_methods):
    """Return the per-method weights as a float32 array that sums to one.

    Args:
        weights: Optional sequence with one weight per method. When it is
            empty or ``None`` every method receives the same weight.
        n_methods: Number of prediction columns taking part in the vote.

    Returns:
        A 1-D ``np.float32`` array of length ``n_methods``.

    Raises:
        ValueError: If a weight is not finite or the weights sum to zero,
            because normalising would then produce NaN/inf vote weights.
    """
    if weights:
        logger.info("Applying user-provided weights.")
        weights_arr = np.array(weights, dtype=np.float32)
    else:
        # Each method is treated equally unless user weights are provided.
        logger.info("Initializing weights to matrix of ones")
        weights_arr = np.ones(n_methods, dtype=np.float32)

    total = weights_arr.sum()
    if not np.isfinite(weights_arr).all() or total == 0:
        raise ValueError(
            f"--weights must be finite numbers with a non-zero sum. Got {weights}."
        )

    logger.info("Normalizing weights")
    return weights_arr / total


def _weighted_majority_vote(obs, prediction_cols, prob_cols, method_weights, tie_label):
    """Compute the consensus label and its score for every row of ``obs``.

    Args:
        obs: DataFrame holding the prediction (and optional probability)
            columns; its index identifies the cells.
        prediction_cols: Names of the columns with the predicted labels.
        prob_cols: Optional names of the probability columns, in the same
            order as ``prediction_cols``.
        method_weights: 1-D array with one normalised weight per method.
        tie_label: Label assigned when several labels share the top weight.

    Returns:
        A ``(labels, scores)`` tuple of Series indexed by cell id. ``scores``
        is the winning summed weight divided by the cell's total weight,
        with 0.0 where that total is zero.

    Raises:
        ValueError: If any vote weight is missing (NaN), e.g. because a
            probability column contains missing values.
    """
    # Broadcast the method weights to one row per cell. float64 is used so the
    # products and sums below are accumulated in double precision.
    vote_weights = pd.DataFrame(
        np.tile(method_weights.astype(np.float64), (len(obs), 1)),
        index=obs.index,
        columns=prediction_cols,
    )
    if prob_cols:
        logger.info("Scaling the weights with the probabilities from each method")
        vote_weights = vote_weights * obs[prob_cols].astype(np.float32).to_numpy()
    # Explicit check instead of `assert`, which is removed when Python runs
    # with -O and would let NaN weights silently corrupt the vote.
    if vote_weights.isna().any(axis=None):
        raise ValueError(
            "Vote weights contain missing values; "
            "check the probability columns for NaN entries."
        )

    logger.info("Computing weighted majority vote.")
    pred_df = obs[prediction_cols].astype(str)

    # For each cell and each method (index), get the label and the weight
    incidences_weights = pd.DataFrame(
        {"label": pred_df.stack(), "weights": vote_weights.stack()}
    )
    # Move the label to the index, there might be duplicate indices now
    incidences_weights = incidences_weights.set_index("label", append=True).rename_axis(
        ["cell_id", "method", "label"]
    )
    # Sum the weights per label; the label with the largest sum wins
    summed_weights = incidences_weights.groupby(level=["cell_id", "label"]).sum()
    # Find the weight that is the largest per cell
    max_weight_per_group = summed_weights.groupby(level="cell_id").transform("max")
    # Use the value to look up the corresponding cells and labels
    max_weights_mask = summed_weights["weights"] == max_weight_per_group["weights"]
    entries_for_max_weights = summed_weights[max_weights_mask].reset_index(
        level="label"
    )
    # A cell with more than one label at the maximum weight is a tie
    is_tied = max_weights_mask.groupby(level="cell_id").sum() > 1
    # For the ties, overwrite the label on every row belonging to that cell
    entries_for_max_weights.loc[is_tied, ["label"]] = tie_label
    # Now it is safe to keep only the first row per cell: for tied cells the
    # label and the weight are identical on all of their rows.
    entries_for_max_weights = entries_for_max_weights[
        ~entries_for_max_weights.index.duplicated()
    ]
    # Normalize the winning weight by the total weight of the cell
    normalized_scores = (
        entries_for_max_weights["weights"]
        / incidences_weights["weights"].groupby(level="cell_id").sum()
    )
    # Handle division by 0
    normalized_scores = normalized_scores.replace([np.inf, -np.inf], 0.0).fillna(0.0)
    return entries_for_max_weights["label"], normalized_scores


def main():
    """Compute a weighted consensus label per cell and write it to the output.

    Reads the modality ``par["modality"]`` from ``par["input"]``, combines the
    label columns in ``par["input_obs_predictions"]`` by a weighted majority
    vote (optionally scaled by ``par["input_obs_probabilities"]`` and
    ``par["weights"]``), stores the winning label and its normalised score in
    ``.obs`` under ``par["output_obs_predictions"]`` and
    ``par["output_obs_score"]``, and writes the result to ``par["output"]``.

    Raises:
        ValueError: If no prediction column is given, the argument lengths do
            not match, the weights are unusable, a requested column is
            missing from ``.obs``, or the vote weights contain NaN.
    """
    prediction_cols = par["input_obs_predictions"]
    prob_cols = par["input_obs_probabilities"]
    weights = par["weights"]

    # Validate the arguments before any I/O so bad input fails fast.
    if not prediction_cols:
        raise ValueError("--input_obs_predictions must contain at least one column.")
    if weights and len(weights) != len(prediction_cols):
        raise ValueError(
            "--weights must have the same length as --input_obs_predictions. "
            f"Got {len(weights)} weights for {len(prediction_cols)} prediction columns."
        )
    if prob_cols and len(prob_cols) != len(prediction_cols):
        raise ValueError(
            "--input_obs_probabilities must have the same length as --input_obs_predictions. "
            f"Got {len(prob_cols)} probability columns for {len(prediction_cols)} prediction columns."
        )
    method_weights = _resolve_method_weights(weights, len(prediction_cols))

    logger.info("Reading input data.")
    adata = mu.read_h5ad(par["input"], mod=par["modality"])

    for col in [*prediction_cols, *(prob_cols or [])]:
        if col not in adata.obs.columns:
            raise ValueError(f"Column '{col}' not found in .obs.")

    labels, scores = _weighted_majority_vote(
        adata.obs, prediction_cols, prob_cols, method_weights, par["tie_label"]
    )

    logger.info("Moving the output to the anndata.")
    # Assignment aligns on the cell index, so the row order of the vote
    # result does not need to match the order of adata.obs.
    adata.obs[par["output_obs_predictions"]] = labels.astype("category")
    adata.obs[par["output_obs_score"]] = scores

    logger.info("Writing output data...")
    write_h5ad_to_h5mu_with_compression(
        par["output"], par["input"], par["modality"], adata, par["output_compression"]
    )
| 120 | + |
| 121 | + |
# Run the component only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
0 commit comments