diff --git a/R/ggsurvfit_align_plots.R b/R/ggsurvfit_align_plots.R index 7511227..6b66d8f 100644 --- a/R/ggsurvfit_align_plots.R +++ b/R/ggsurvfit_align_plots.R @@ -107,4 +107,4 @@ ggsurvfit_align_plots <- function(pltlist) { plots_grobs_xcols[[1]]$grobs[[13]]$children[[1]]$x <- grid::unit(x, "cm") plots_grobs_xcols -} +} \ No newline at end of file diff --git a/R/utils-add_risktable.R b/R/utils-add_risktable.R index ffd3a28..6d061f2 100644 --- a/R/utils-add_risktable.R +++ b/R/utils-add_risktable.R @@ -3,26 +3,25 @@ combine_groups, risktable_group, risktable_height, theme, combine_plots, risktable_symbol_args, ...) { - # check iputs ---------------------------------------------------------------- + # check inputs ---------------------------------------------------------------- if (!is.null(risktable_height) && (length(risktable_height) > 1 || !is.numeric(risktable_height) || !dplyr::between(risktable_height, 0, 1))) { cli_abort("The {.code add_risktable(risktable_height=)} argument must be a scalar between 0 and 1.") } - + # build the ggplot to inspect the internals ---------------------------------- plot_build <- suppressWarnings(ggplot2::ggplot_build(x)) - + # if plot is faceted, return plot without risktable -------------------------- if (.is_faceted(plot_build)) { return(structure(x, class = setdiff(class(x), c("ggsurvfit", "ggcuminc")))) } - + # get data to place in risktables -------------------------------------------- times <- times %||% plot_build$layout$panel_params[[1]]$x$breaks df_times <- .prepare_data_for_risk_tables(data = x$data, times = times, combine_groups = combine_groups) - - + # determine grouping if not specified ---------------------------------------- if (risktable_group == "auto") { risktable_group <- @@ -32,14 +31,63 @@ TRUE ~ "strata" ) } - + # determine risktable height ------------------------------------------------- risktable_height <- .calculate_risktable_height(risktable_height, risktable_group, risktable_stats, df_times) - + # create list of ggplots, one plot for each risktable ------------------------ df_stat_labels <- .construct_stat_labels(risktable_stats, stats_label) - + + # PATCHWORK::FREE() APPROACH WITH COORDINATE INTEGRATION -------------------- + if (isTRUE(combine_plots)) { + # Extract coordinate system from main plot BEFORE risk table construction + main_x_breaks <- plot_build$layout$panel_params[[1]]$x$breaks + main_x_range <- plot_build$layout$panel_params[[1]]$x$range + + # Create risk tables WITH coordinate system built in (user's approach) + gg_risktable_list <- + .create_list_of_gg_risk_tables( + df_times, risktable_stats, times, + df_stat_labels, theme, risktable_group, + color_block_mapping = + .match_strata_level_to_color(plot_build, risktable_group, risktable_symbol_args), + risktable_symbol_args = risktable_symbol_args, + x_breaks = main_x_breaks, # Pass coordinate info + x_range = main_x_range, # Pass coordinate info + ... + ) + + # Apply patchwork::free() to main plot (prevents y-axis title shifting) + main_plot_free <- patchwork::free(x, type = "space", side = "l") + + # Combine using patchwork exactly like the user's successful example + if (length(gg_risktable_list) == 1) { + # Single risk table case + gg_combined <- main_plot_free / gg_risktable_list[[1]] + gg_combined <- gg_combined + patchwork::plot_layout( + heights = c(1 - risktable_height, risktable_height) + ) + } else { + # Multiple risk tables case + gg_combined <- main_plot_free + for (i in seq_along(gg_risktable_list)) { + gg_combined <- gg_combined / gg_risktable_list[[i]] + } + + # Calculate heights + n_tables <- length(gg_risktable_list) + table_height_each <- risktable_height / n_tables + all_heights <- c(1 - risktable_height, rep(table_height_each, n_tables)) + + gg_combined <- gg_combined + patchwork::plot_layout(heights = all_heights) + } + + return(gg_combined) + } + + # FALLBACK: ORIGINAL METHOD FOR combine_plots = FALSE ----------------------- + # Create risk tables without coordinate integration for backward compatibility gg_risktable_list <- .create_list_of_gg_risk_tables( df_times, risktable_stats, times, @@ -47,17 +95,18 @@ color_block_mapping = .match_strata_level_to_color(plot_build, risktable_group, risktable_symbol_args), risktable_symbol_args = risktable_symbol_args, + # No coordinate parameters for original method ... ) - + # align all the plots -------------------------------------------------------- gg_risktable_list_aligned <- c(list(x), gg_risktable_list) %>% ggsurvfit_align_plots() - + # combine all plots into single figure --------------------------------------- if (isFALSE(combine_plots)) return(gg_risktable_list_aligned) - + risktable_n <- length(gg_risktable_list_aligned) - 1 gg_final <- gg_risktable_list_aligned %>% @@ -67,7 +116,7 @@ c(1 - risktable_height, rep_len(risktable_height / risktable_n, length.out = risktable_n)) ) - + gg_final } @@ -162,7 +211,9 @@ lst_stat_labels_default <- df_stat_labels, theme, risktable_group, color_block_mapping, - risktable_symbol_args, ...) { + risktable_symbol_args, + x_breaks = NULL, + x_range = NULL, ...) { grouping_variable <- switch(risktable_group, "strata" = "strata", @@ -224,6 +275,13 @@ lst_stat_labels_default <- ) + rlang::inject(ggplot2::geom_text(!!!geom_text_args)) + # Apply coordinate system during construction + if (!is.null(x_breaks) && !is.null(x_range)) { + gg <- gg + + ggplot2::scale_x_continuous(breaks = x_breaks) + + ggplot2::coord_cartesian(xlim = x_range, expand = FALSE, clip = "off") + } + # apply styling to the plot gg + ggtitle_group_lbl + diff --git a/tests/testthat/test-add_risktable.R b/tests/testthat/test-add_risktable.R index c908574..f8dfe81 100644 --- a/tests/testthat/test-add_risktable.R +++ b/tests/testthat/test-add_risktable.R @@ -336,12 +336,12 @@ test_that("add_risktable() works with ggsurvfit() `start.time` and negative time test_that("add_risktable() works with multiple survival endpoints (Issue #212)", { - + os_data <- df_lung %>% dplyr::mutate(PARAM = "Overall Survival") pfs_data <- df_lung %>% dplyr::mutate(time = time * 0.7, PARAM = "Progression-Free Survival") combined_data <- dplyr::bind_rows(os_data, pfs_data) - - + + expect_error( p <- survfit2(Surv(time, status) ~ PARAM, data = combined_data) %>% ggsurvfit() + add_risktable(), @@ -350,3 +350,122 @@ test_that("add_risktable() works with multiple survival endpoints (Issue #212)", expect_error(print(p), NA) }) + +test_that("add_risktable() handles large numbers and long labels without overlapping (Issue #230)", { + + # Large patient cohort with descriptive strata labels + set.seed(123) # For reproducible results + + large_cohort_data <- data.frame( + time = c( + # Extended Time Since Surgery group - longer survival times + rexp(800, rate = 0.15), + # Limited Time Since Surgery group - shorter survival times + rexp(1200, rate = 0.25) + ), + status = c( + rbinom(800, 1, 0.6), # 60% event rate for extended group + rbinom(1200, 1, 0.75) # 75% event rate for limited group + ), + surgery_timing = factor(c( + rep("Extended Time Since Surgery", 800), + rep("Limited Time Since Surgery", 1200) + )) + ) + + # Create survfit object with large numbers + sf_large_cohort <- survfit2(Surv(time, status) ~ surgery_timing, data = large_cohort_data) + + # Large numbers at time 0: ~800 and ~1200 patients at risk + expect_error( + p_issue_230 <- sf_large_cohort %>% + ggsurvfit() + + add_risktable(risktable_stats = "n.risk"), + NA + ) + + expect_error(print(p_issue_230), NA) + + # Test with the format from the user's image: "At risk (censored)" + expect_error( + p_issue_230_with_censored <- sf_large_cohort %>% + ggsurvfit() + + add_risktable( + risktable_stats = "{n.risk} ({cum.censor})", + stats_label = "At risk (censored)" + ), + NA + ) + + expect_error(print(p_issue_230_with_censored), NA) + + # Test that the plot actually has large numbers at time 0 + risk_data <- sf_large_cohort %>% tidy_survfit(times = 0) + expect_true(any(risk_data$n.risk >= 500), + info = "Should have large patient numbers at baseline") + + # Test with even longer strata names that would definitely cause issues + very_long_labels_data <- large_cohort_data %>% + dplyr::mutate( + surgery_timing = factor( + surgery_timing, + levels = c("Extended Time Since Surgery", "Limited Time Since Surgery"), + labels = c( + "Extended Time Between Surgery and Treatment Initiation", + "Limited Time Between Surgery and Treatment Initiation" + ) + ) + ) + + sf_very_long <- survfit2(Surv(time, status) ~ surgery_timing, data = very_long_labels_data) + + expect_error( + p_very_long_labels <- sf_very_long %>% + ggsurvfit() + + add_risktable(risktable_stats = "n.risk"), + NA + ) + + expect_error(print(p_very_long_labels), NA) + + # Skip visual tests on CI but include them for local testing + skip_on_ci() + vdiffr::expect_doppelganger("issue-230-large-numbers", p_issue_230) + vdiffr::expect_doppelganger("issue-230-with-censored", p_issue_230_with_censored) + vdiffr::expect_doppelganger("very-long-labels", p_very_long_labels) +}) + +# Additional test specifically for the overlapping issue +test_that("add_risktable() prevents text overlapping with patchwork::free()", { + # Create a scenario guaranteed to cause overlapping without the fix + overlap_data <- data.frame( + time = rexp(2000, 0.1), # Very large cohort + status = rbinom(2000, 1, 0.5), + group = factor(c( + rep("Group with extremely long descriptive name that would cause overlap", 1000), + rep("Another group with very long name causing alignment issues", 1000) + )) + ) + + sf_overlap <- survfit2(Surv(time, status) ~ group, data = overlap_data) + + # This would definitely cause overlapping without patchwork::free() + expect_error( + p_overlap_test <- sf_overlap %>% + ggsurvfit() + + add_risktable(risktable_stats = "n.risk") + + # Force narrow margins to test the alignment fix + theme(plot.margin = margin(0.1, 0.1, 0.1, 0.1, "cm")), + NA + ) + + expect_error(print(p_overlap_test), NA) + + # Test that numbers at time 0 are indeed large (>1000) + baseline_risk <- sf_overlap %>% tidy_survfit(times = 0) + expect_true(all(baseline_risk$n.risk >= 900), + info = "All groups should have large patient numbers") + + skip_on_ci() + vdiffr::expect_doppelganger("overlap-prevention-test", p_overlap_test) +})