diff --git a/R/visualization.R b/R/visualization.R index 65ef5c94a..c363c9cf1 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -34,10 +34,11 @@ setGeneric( #' @param balanced Plot an equal number of genes with both + and - scores. #' @param projected Use the full projected dimensional reduction #' @param ncol Number of columns to plot -#' @param fast If true, use \code{image} to generate plots; faster than using ggplot2, but not customizable +#' @param fast If true, use \code{image} to generate plots; faster than using ggplot2, but not customizable and excludes figure legend in output #' @param assays A vector of assays to pull data from -#' @param combine Combine plots into a single \code{\link[patchwork]{patchwork}ed} +#' @param combine Combine plots into a single \code{\link[patchwork]{patchwork}ed} with single shared figure legend when \code{fast=FALSE} #' ggplot object. If \code{FALSE}, return a list of ggplot objects +#' @param leg.pos When \code{combine=TRUE}, allows legend position to be adjusted for \code{\link[patchwork]{patchwork}ed} output; defaults as \code{"right"} #' #' @return No return value by default. If using fast = FALSE, will return a #' \code{\link[patchwork]{patchwork}ed} ggplot object if combine = TRUE, otherwise @@ -54,21 +55,22 @@ setGeneric( #' DimHeatmap(object = pbmc_small) #' DimHeatmap <- function( - object, - dims = 1, - nfeatures = 30, - cells = NULL, - reduction = 'pca', - disp.min = -2.5, - disp.max = NULL, - balanced = TRUE, - projected = FALSE, - ncol = NULL, - fast = TRUE, - raster = TRUE, - slot = 'scale.data', - assays = NULL, - combine = TRUE + object, + dims = 1, + nfeatures = 30, + cells = NULL, + reduction = 'pca', + disp.min = -2.5, + disp.max = NULL, + balanced = TRUE, + projected = FALSE, + ncol = NULL, + fast = TRUE, + raster = TRUE, + slot = 'scale.data', + assays = NULL, + combine = TRUE, + leg.pos = "right" ) { ncol <- ncol %||% ifelse(test = length(x = dims) > 2, yes = 3, no = length(x = dims)) plots <- vector(mode = 'list', length = length(x = dims)) @@ -171,6 +173,9 @@ DimHeatmap <- function( cell.order = dim.cells, feature.order = dim.features ) + plots[[i]] <- plots[[i]] + + ggtitle(paste0(Key(object = object[[reduction]]), dims[i])) + + theme(plot.title = element_text(hjust = 0.5, face = "bold")) } } if (fast) { @@ -178,7 +183,9 @@ DimHeatmap <- function( return(invisible(x = NULL)) } if (combine) { - plots <- wrap_plots(plots, ncol = ncol, guides = "collect") + plots <- wrap_plots(plots, ncol = ncol, guides = "collect") + + plot_layout(guides = "collect") & + theme(legend.position = leg.pos) } return(plots) }