diff --git a/R/sketching.R b/R/sketching.R index 8bb260595..8efc5dd88 100644 --- a/R/sketching.R +++ b/R/sketching.R @@ -20,12 +20,14 @@ NULL #' #' @param object A Seurat object. #' @param assay Assay name. Default is NULL, in which case the default assay of the object is used. -#' @param ncells A positive integer indicating the number of cells to sample for the sketching. Default is 5000. +#' @param cell.ratio Proportion of cells to sample from each layer. Default is 0.25. +#' @param min.cells Minimum number of cells a layer must have in order to be subsampled; smaller layers are kept whole, and larger layers contribute at least this many cells. Default is 2500. #' @param sketched.assay Sketched assay name. A sketch assay is created or overwrite with the sketch data. Default is 'sketch'. #' @param method Sketching method to use. Can be 'LeverageScore' or 'Uniform'. #' Default is 'LeverageScore'. #' @param var.name A metadata column name to store the leverage scores. Default is 'leverage.score'. #' @param over.write whether to overwrite existing column in the metadata. Default is FALSE. +#' @param leverage.already.calculated Whether leverage scores have already been calculated and stored in the 'var.name' metadata column; if TRUE and 'over.write' is FALSE, the existing scores are reused instead of being recomputed. Default is FALSE. #' @param seed A positive integer for the seed of the random number generator. Default is 123. #' @param cast The type to cast the resulting assay to. Default is 'dgCMatrix'. #' @param verbose Print progress and diagnostic messages @@ -41,50 +43,51 @@ NULL SketchData <- function( object, assay = NULL, - ncells = 5000L, + cell.ratio = 0.25, + min.cells = 2500, sketched.assay = 'sketch', method = c('LeverageScore', 'Uniform'), var.name = "leverage.score", - over.write = FALSE, + over.write = F, + leverage.already.calculated = F, seed = 123L, cast = 'dgCMatrix', - verbose = TRUE, + verbose = T, ...
) { assay <- assay[1L] %||% DefaultAssay(object = object) - assay <- match.arg(arg = assay, choices = Assays(object = object)) + assay <- match.arg(arg = assay, choices = SeuratObject::Assays(object = object)) method <- match.arg(arg = method) if (sketched.assay == assay) { - abort(message = "Cannot overwrite existing assays") + rlang::abort(message = "Cannot overwrite existing assays") } - if (sketched.assay %in% Assays(object = object)) { + if (sketched.assay %in% SeuratObject::Assays(object = object)) { if (sketched.assay == DefaultAssay(object = object)) { DefaultAssay(object = object) <- assay } object[[sketched.assay]] <- NULL } - if (!over.write) { - var.name <- CheckMetaVarName(object = object, var.name = var.name) - } - - if (method == 'LeverageScore') { - if (verbose) { - message("Calcuating Leverage Score") - } - object <- LeverageScore( - object = object, - assay = assay, - var.name = var.name, - over.write = over.write, - seed = seed, - verbose = FALSE, - ... - ) - } else if (method == 'Uniform') { - if (verbose) { - message("Uniformly sampling") + + if (over.write == T | leverage.already.calculated == F) { + if (method == 'LeverageScore') { + if (verbose) { + message("Calcuating Leverage Score") + } + object <- LeverageScore( + object = object, + assay = assay, + var.name = var.name, + over.write = over.write, + seed = seed, + verbose = verbose, + ... 
+ ) + } else if (method == 'Uniform') { + if (verbose) { + message("Uniformly sampling") + } + object[[var.name]] <- 1 } - object[[var.name]] <- 1 } leverage.score <- object[[var.name]] layers.data <- Layers(object = object[[assay]], search = 'data') @@ -93,12 +96,14 @@ SketchData <- function( FUN = function(i, seed) { set.seed(seed = seed) lcells <- Cells(x = object[[assay]], layer = layers.data[i]) - if (length(x = lcells) < ncells) { - return(lcells) + if (length(lcells) < min.cells) { + ncells_per_sample = length(lcells) + } else { + ncells_per_sample = max(round(length(lcells)*cell.ratio), min.cells) } return(sample( x = lcells, - size = ncells, + size = ncells_per_sample, prob = leverage.score[lcells,] )) }, @@ -113,13 +118,13 @@ SketchData <- function( try( expr = VariableFeatures(object = sketched, method = "sketch", layer = lyr) <- VariableFeatures(object = object[[assay]], layer = lyr), - silent = TRUE + silent = F ) } if (!is.null(x = cast) && inherits(x = sketched, what = 'Assay5')) { sketched <- CastAssay(object = sketched, to = cast, ...) } - Key(object = sketched) <- Key(object = sketched.assay, quiet = TRUE) + Key(object = sketched) <- Key(object = sketched.assay, quiet = F) object[[sketched.assay]] <- sketched DefaultAssay(object = object) <- sketched.assay return(object) @@ -369,6 +374,7 @@ TransferSketchLabels <- function( #' @param seed A positive integer. The seed for the random number generator, defaults to 123. 
#' @param verbose Print progress and diagnostic messages #' @importFrom Matrix qrR t +#' @importFrom matrixcalc is.singular.matrix #' @importFrom irlba irlba #' #' @rdname LeverageScore @@ -448,23 +454,30 @@ LeverageScore.default <- function( } else { base::qr.R(qr = qr.sa) } - R.inv <- as.sparse(x = backsolve(r = R, x = diag(x = ncol(x = R)))) - if (isTRUE(x = verbose)) { - message("Performing random projection") - } - JL <- as.sparse(x = JLEmbed( - nrow = ncol(x = R.inv), - ncol = ndims, - eps = eps, - seed = seed - )) - Z <- object %*% (R.inv %*% JL) - if (inherits(x = Z, what = 'IterableMatrix')) { - Z.score <- BPCells::matrix_stats(matrix = Z ^ 2, row_stats = 'mean' - )$row_stats['mean',]*ncol(x = Z) - } else { - Z.score <- rowSums(x = Z ^ 2) - } + A <- diag(x = R) + if (any(A == 0)) { + bad_elem <- which(A == 0) + message(paste0("Found 0 in diagonal of input matrix at ", bad_elem, ". Assigning all cells leverage score of 1")) + Z.score <- rep(1, nrow(x = object)) + } else { + R.inv <- as.sparse(x = backsolve(r = R, x = diag(x = ncol(x = R)))) + if (isTRUE(x = verbose)) { + message("Performing random projection") + } + JL <- as.sparse(x = JLEmbed( + nrow = ncol(x = R.inv), + ncol = ndims, + eps = eps, + seed = seed + )) + Z <- object %*% (R.inv %*% JL) + if (inherits(x = Z, what = 'IterableMatrix')) { + Z.score <- BPCells::matrix_stats(matrix = Z ^ 2, row_stats = 'mean' + )$row_stats['mean',]*ncol(x = Z) + } else { + Z.score <- rowSums(x = Z ^ 2) + } + } return(Z.score) }