Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 177 additions & 79 deletions bin/02_filter_sqanti_transcripts.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ filter_protein_coding <- as.logical(Sys.getenv("PROTEIN_CODING_FILTER", "TRUE"
filter_internal_priming <- as.logical(Sys.getenv("INTERNAL_PRIMING_FILTER", "TRUE"))
filter_RTS <- as.logical(Sys.getenv("TEMPLATE_SWITCHING_FILTER", "TRUE"))
percent_polyA_threshold <- as.numeric(Sys.getenv("PERCENT_POLYA_THRESHOLD", "60"))
tclass_to_keep <- Sys.getenv("TRANSCRIPT_CLASS_KEEP", "FSM,NIC,NNC")
tclass_to_keep <- Sys.getenv("TRANSCRIPT_CLASS_KEEP", "FSM,ISM,NIC,NNC")
#structural_level <- Sys.getenv("STRUCTURE_FILTER", "strict")

# =============================================================================
Expand All @@ -57,8 +57,8 @@ stopifnot("File not found" = file.exists(sqanti_fasta))
stopifnot("File not found" = file.exists(mapping_file))

# Create additional output directory
dropout_dir = file.path(dir, "dropout")
dir.create(dropout_dir, recursive = TRUE, showWarnings = FALSE)
#dropout_dir = file.path(dir, "dropout")
#dir.create(dropout_dir, recursive = TRUE, showWarnings = FALSE)

# =============================================================================
# Helper functions
Expand All @@ -79,34 +79,129 @@ save_filtered_fasta <- function(fasta_file, keep_ids, output_file) {
message(paste0("Saved ", length(filtered_sequences), " sequences to ", basename(output_file)))
}

#' Save filtered GTF
#' @param gtf_file Input GTF file
#' @param keep_ids Transcript IDs to keep
#' @param output_file Output GTF file
save_filtered_gtf <- function(gtf_file, keep_ids, output_file) {

  gtf = import(gtf_file)
  filtered_gtf = gtf[gtf$transcript_id %in% keep_ids]

  # Write filtered GTF
  export(filtered_gtf, output_file)
  message(paste0("Saved ", length(unique(filtered_gtf$transcript_id)), " transcripts to ", basename(output_file)))
}

#' Convert GTF to colored BED12 for genome browser visualization
#'
#' Writes one BED12 line per transcript feature. Blocks come from exon
#' features when the GTF has any, otherwise from CDS features. The thick
#' region spans the transcript's CDS; transcripts without CDS are written
#' with an empty thick region (thickStart == thickEnd == chromStart).
#'
#' @param gtf_path Path to input GTF file
#' @param output_bed Path to output BED12 file
#' @param color_by Column name for coloring (numeric, 0-1). High values =
#'   purple, low = yellow. Values outside 0-1 are clamped; NA values get a
#'   mid-scale color. Default: NULL (black)
#' @return `output_bed`, invisibly.
gtf_to_bed12 <- function(gtf_path, output_bed, color_by = NULL) {

  # Import GTF and flatten to one row per feature
  dt <- as.data.table(as.data.frame(import(gtf_path)))

  # Get transcripts
  tx <- dt[type == "transcript"]

  # Determine what to use for blocks
  has_exon <- any(dt$type == "exon")
  has_cds <- any(dt$type == "CDS")

  if (has_exon) {
    ex <- dt[type == "exon"] # Use exons for blocks (shows UTRs)
  } else if (has_cds) {
    ex <- dt[type == "CDS"] # Use CDS for blocks (CDS only)
  } else {
    stop("GTF must contain either exon or CDS features")
  }

  cds <- dt[type == "CDS"] # Always get CDS for thick regions

  # Set keys for fast joins
  setkey(ex, transcript_id)
  if (has_cds) setkey(cds, transcript_id)

  if (!is.null(color_by)) {
    if (!color_by %in% names(tx)) {
      stop("Column '", color_by, "' not found in GTF attributes", call. = FALSE)
    }

    # Sort transcripts by gene, then by the coloring value (descending)
    tx[, sort_val := as.numeric(get(color_by))]
    tx <- tx[order(gene_id, -sort_val)]

    # Pre-compute colors. Clamp to [0, 1] so the palette index always stays
    # within 1..100; NA/NaN values fall back to the middle of the scale.
    vals <- pmin(pmax(tx$sort_val, 0), 1)
    col_idx <- round((1 - vals) * 99) + 1
    col_idx[is.na(col_idx)] <- 50
    rgb_matrix <- col2rgb(viridis::viridis(100)[col_idx])
    tx$color <- paste(rgb_matrix[1, ], rgb_matrix[2, ], rgb_matrix[3, ], sep = ",")
  } else {
    tx$color <- "0,0,0"
  }

  # Use the `name` attribute as the BED name when present, otherwise fall
  # back to the transcript ID so every row still has 12 fields.
  has_name <- "name" %in% names(tx)

  # Build BED12 lines
  bed_rows <- vector("list", nrow(tx))

  # seq_len() so an empty transcript table yields zero iterations
  for (i in seq_len(nrow(tx))) {
    tx_row <- tx[i]

    # nomatch = NULL drops unmatched keys; the default (NA) would return a
    # single all-NA row and defeat the emptiness checks below.
    blocks <- ex[.(tx_row$transcript_id), nomatch = NULL][order(start)]

    if (nrow(blocks) == 0) next

    # Integer arithmetic throughout: a double such as 100000 would be
    # rendered as "1e+05" when converted to character.
    tx_start <- as.integer(tx_row$start)
    chrom_start <- tx_start - 1L

    # Default: no CDS = all thin (empty thick region at chromStart)
    thick_start <- chrom_start
    thick_end <- chrom_start

    # Get CDS boundaries for thick regions
    if (has_cds) {
      tx_cds <- cds[.(tx_row$transcript_id), nomatch = NULL]
      if (nrow(tx_cds) > 0) {
        thick_start <- as.integer(min(tx_cds$start)) - 1L
        thick_end <- as.integer(max(tx_cds$end))
      }
    }

    # Build block strings (block starts are relative to chromStart)
    block_sizes <- paste(as.integer(blocks$width), collapse = ",")
    block_starts <- paste(as.integer(blocks$start) - tx_start, collapse = ",")

    label <- if (has_name) tx_row$name else tx_row$transcript_id

    # Build BED12 line
    bed_rows[[i]] <- paste(
      as.character(tx_row$seqnames),
      chrom_start,
      as.integer(tx_row$end),
      as.character(label),
      "0",
      as.character(tx_row$strand),
      thick_start,
      thick_end,
      as.character(tx_row$color),
      nrow(blocks),
      block_sizes,
      block_starts,
      sep = "\t"
    )
  }

  # unlist() drops skipped (NULL) entries; as.character() keeps the
  # zero-transcript case a valid (empty) input for writeLines().
  writeLines(as.character(unlist(bed_rows)), output_bed)

  invisible(output_bed)
}

# =============================================================================
# Combine sqanti associated gene_ids with gencode ids and hash ids, include counts
# =============================================================================
# =================================================================================
# Combine sqanti associated gene_ids with reference ids and hash ids, include counts
# =================================================================================

# Gencode gtf used to select protein coding genes and add column of readable gene name
# Reference GTF used to select protein-coding genes and to add a column of readable gene names
gencode = import(gencode_gtf, format = "gtf")
gencode_df = as.data.frame(gencode)

gencode_df %<>%
filter(type == "transcript") %>%
select(associated_gene = gene_id, gene_type, gene_name, associated_transcript = transcript_id, transcript_name, protein_id)
select(associated_gene = gene_id,
gene_type = any_of(c("gene_type", "gene_biotype")),
gene_name = any_of(c("gene_name", "gene")),
associated_transcript = transcript_id,
any_of("transcript_name"),
any_of("protein_id"))

gencode_gene = gencode_df %>%
select(associated_gene, gene_type, gene_name) %>%
select(associated_gene, any_of("gene_type"), any_of("gene_name")) %>%
distinct()

# Metatable to add specific sample ids
Expand All @@ -127,19 +222,28 @@ sqanti_df %<>%
~ (.x / sum(.x)) * 1e6,
.names = "{gsub('_counts', '_cpm', .col)}"))

# Keep only isoforms with ENS gene IDs
# Keep only annotated genes
sqanti_df_full = sqanti_df %>%
filter(!is.na(associated_gene), str_starts(associated_gene, "ENS"))
filter(!is.na(associated_gene), !str_starts(associated_gene, "novelGene"))
#filter(!is.na(associated_gene), str_starts(associated_gene, "ENS"))

sqanti_df_full %<>%
left_join(gencode_gene, by = "associated_gene") %>%
left_join(gencode_df, by = c("associated_gene", "associated_transcript", "gene_type", "gene_name"))
left_join(gencode_df)

# Combine with hash IDs and select the cpm columns
hashids = read_tsv(mapping_file)

sqanti_ids = sqanti_df_full %>%
select(isoform_id = isoform, ensg_gene_id = associated_gene, gene_name, enst_transcript_id = associated_transcript, transcript_name, gene_type, protein_id, ends_with("_counts"), ends_with("_cpm"))
select(isoform_id = isoform,
ensg_gene_id = associated_gene,
any_of("gene_name"),
enst_transcript_id = associated_transcript,
any_of("transcript_name"),
any_of("gene_type"),
any_of("protein_id"),
ends_with("_counts"),
ends_with("_cpm"))

all_ids = hashids %>% select(isoform_id = transcript_id, hash_id) %>%
inner_join(sqanti_ids, by = "isoform_id")
Expand All @@ -154,7 +258,7 @@ message("\n Starting filtering of SQANTI output...")

# Initialize dropout tracking
original_ids = sqanti_df_full$isoform
dropout_tracker = list()
sqanti_df_full$dropout_reason = "kept" # Initialize all as kept

# Keep only transcripts of protein coding genes
sqanti_df = sqanti_df_full
Expand All @@ -165,7 +269,8 @@ if (filter_protein_coding) {

kept_ids = sqanti_df$isoform
dropped_ids = setdiff(ids, kept_ids)
dropout_tracker[["not_protein_coding"]] = dropped_ids
sqanti_df_full$dropout_reason[sqanti_df_full$isoform %in% dropped_ids] = "not_protein_coding"
#dropout_tracker[["not_protein_coding"]] = dropped_ids

message("Protein coding filter: kept ", length(kept_ids), " transcripts, dropped ", length(dropped_ids), " transcripts")
}
Expand All @@ -192,7 +297,8 @@ if (filter_internal_priming) {

kept_ids = sqanti_df$isoform
dropped_ids = setdiff(ids, kept_ids)
dropout_tracker[["internal_priming"]] = dropped_ids
sqanti_df_full$dropout_reason[sqanti_df_full$isoform %in% dropped_ids] = "internal_priming"
#dropout_tracker[["internal_priming"]] = dropped_ids

message("Internal priming filter: kept ", length(kept_ids), " transcripts, dropped ", length(dropped_ids), " transcripts")
}
Expand Down Expand Up @@ -223,20 +329,13 @@ if (filter_RTS) {

kept_ids = sqanti_df$isoform
dropped_ids = setdiff(ids, kept_ids)
dropout_tracker[["Template_switching_artifact"]] = dropped_ids
sqanti_df_full$dropout_reason[sqanti_df_full$isoform %in% dropped_ids] = "template_switching"
#dropout_tracker[["Template_switching_artifact"]] = dropped_ids

message("Template switching filter: kept ", length(kept_ids), " transcripts, dropped ", length(dropped_ids), " transcripts")
}

# Apply structural category filter
# STRUCTURAL_CATEGORIES <- list(
# strict = c("novel_not_in_catalog", "novel_in_catalog",
# "incomplete-splice_match", "full-splice_match"),
# all = c("antisense", "novel_not_in_catalog", "novel_in_catalog",
# "incomplete-splice_match", "full-splice_match", "genic",
# "intergenic", "fusion", "genic_intron")
# )

if (toupper(tclass_to_keep) == "ALL") {
# Keep everything - no filtering
ids = sqanti_df$isoform
Expand Down Expand Up @@ -271,21 +370,24 @@ if (toupper(tclass_to_keep) == "ALL") {

#allowed_categories = STRUCTURAL_CATEGORIES[[structural_level]]
dropped_ids = setdiff(ids, kept_ids)
dropout_tracker[["structural_category_filtered"]] = dropped_ids
sqanti_df_full$dropout_reason[sqanti_df_full$isoform %in% dropped_ids] = "structural_category"
#dropout_tracker[["structural_category_filtered"]] = dropped_ids
message("Structural category filter: kept ", length(kept_ids), " transcripts, dropped ", length(dropped_ids), " transcripts")

# Calculate final results
all_dropout_ids = setdiff(original_ids, kept_ids)
all_dropout_ids = sqanti_df_full$isoform[sqanti_df_full$dropout_reason != "kept"]

#all_dropout_ids = setdiff(original_ids, kept_ids)

message(paste0("Final: ", length(kept_ids), " kept, ", length(all_dropout_ids), " dropped"))
#message(paste0("Final: ", length(kept_ids), " kept, ", length(all_dropout_ids), " dropped"))

# =============================================================================
# Write output files and generate summary
# =============================================================================

message("\n Writing output files...")

# Write out filtered files
write_tsv(sqanti_df_full, file.path(dir, paste0(basename, "_classification_all.txt")))
write_tsv(sqanti_df, file.path(dir, paste0(basename, "_classification_filtered.txt")))

save_filtered_fasta(
Expand All @@ -294,52 +396,48 @@ save_filtered_fasta(
file.path(dir, paste0(basename, "_corrected_filtered.fasta"))
)

save_filtered_gtf(
sqanti_gtf,
kept_ids,
file.path(dir, paste0(basename, "_corrected_filtered.gtf"))
)
# gtf and bed
# Read the sample GTF as a data frame and keep only the retained transcripts
gtf <- import(sqanti_gtf) %>%
  as.data.frame()

filtered_gtf <- gtf[gtf$transcript_id %in% kept_ids, ]
gtf_ids <- filtered_gtf %>% distinct(gene_id, transcript_id)

# The reference gene_name column is optional upstream (selected with any_of()),
# so fall back to the gene ID column when it is absent.
# NOTE(review): assumes all_ids still carries ensg_gene_id at this point - confirm.
gene_label_col <- if ("gene_name" %in% colnames(all_ids)) "gene_name" else "ensg_gene_id"

# Per-transcript attributes: average cpm across the cpm columns, expressed as a
# fraction of the summed average cpm of all kept transcripts of the same gene
new_attributes <- all_ids %>%
  filter(isoform_id %in% kept_ids) %>%
  select(transcript_id = isoform_id, everything()) %>%
  left_join(gtf_ids, by = "transcript_id") %>%
  mutate(avg_cpm = rowMeans(across(contains("cpm")), na.rm = TRUE)) %>%
  group_by(gene_id) %>%
  mutate(
    gene_total_cpm = sum(avg_cpm, na.rm = TRUE),
    avg_ratio = round(avg_cpm / gene_total_cpm, 3)
  ) %>%
  ungroup() %>%
  select(transcript_id, reference_gene_name = all_of(gene_label_col), avg_ratio)

# Attach the new attributes to every GTF row and build the display name
filtered_gtf %<>%
  left_join(new_attributes, by = "transcript_id") %>%
  mutate(name = paste0(transcript_id, "|",
                       reference_gene_name, "|",
                       avg_ratio))

filtered_gtf_path <- file.path(dir, paste0(basename, "_corrected_filtered.gtf"))

gr_updated <- makeGRangesFromDataFrame(filtered_gtf, keep.extra.columns = TRUE)
export(gr_updated, filtered_gtf_path, format = "gtf")

# convert to bed12
gtf_to_bed12(gtf_path = filtered_gtf_path,
             output_bed = file.path(dir, paste0(basename, "_corrected_filtered.bed")),
             color_by = "avg_ratio")

# Filtered hashid and cpm table
all_ids %>% filter(isoform_id %in% sqanti_df$isoform) %>%
write_tsv(file.path(dir, paste0(basename, "_hashids_with_cpm_filtered.txt")))

# Create dropout reasons table and write dropout files
dropout_reasons_df = map_dfr(names(dropout_tracker), ~ {
tibble(
isoform = dropout_tracker[[.x]],
dropout_reason = .x
)
})

dropout_reasons_df %<>%
left_join(sqanti_df_full)

write_tsv(dropout_reasons_df, file.path(dropout_dir, paste0(basename, "_dropout_transcripts.tsv")))

save_filtered_fasta(
sqanti_fasta,
all_dropout_ids,
file.path(dropout_dir, paste0(basename, "_corrected_dropout.fasta"))
)

save_filtered_gtf(
sqanti_gtf,
all_dropout_ids,
file.path(dropout_dir, paste0(basename, "_corrected_dropout.gtf"))
)

# Print summary
cat("\n=== FILTERING SUMMARY ===\n")
cat(paste0("Sample: ", basename, "\n"))
cat(paste0("Original transcripts: ", length(original_ids), "\n"))
cat(paste0("Final transcripts: ", length(kept_ids), "\n"))
cat(paste0("Retention rate: ", round(100 * length(kept_ids) / length(original_ids), 1), "%\n"))
cat("\nDropout breakdown:\n")

# Print dropout breakdown
for (reason in names(dropout_tracker)) {
count = length(dropout_tracker[[reason]])
pct = round(100 * count / length(original_ids), 1)
cat(paste0(" ", reason, ": ", count, " (", pct, "%)\n"))
}
cat(paste0("Retention rate: ", round(100 * length(kept_ids) / length(original_ids), 1), "%\n"))
Loading
Loading