## ----setup, message=FALSE, warning=FALSE--------------------------------------
# Decide whether the remaining vignette chunks can be evaluated. They need
# BiocFileCache (to fetch the example data), curl (for the connectivity
# probe below) and a working internet connection.
# Both requireNamespace() checks come before curl::has_internet(): `&&`
# short-circuits, so a missing curl installation yields FALSE instead of
# raising an error inside the guard itself.
has_connectivity <- requireNamespace("BiocFileCache", quietly = TRUE) &&
    requireNamespace("curl", quietly = TRUE) &&
    curl::has_internet()

# Chunk defaults (figure resolution/size); `eval` turns chunk evaluation
# off when the prerequisites above are not met.
knitr::opts_chunk$set(dpi = 72, fig.width = 5, fig.height = 4,
                      eval = has_connectivity)

# Tell the reader why nothing was evaluated; the cause may be a missing
# package rather than the network, so the message names both.
if (!has_connectivity) {
    message("BiocFileCache/curl unavailable or no internet connection; ",
            "vignette code not evaluated.")
}

library(STADyUM)           # For transcription rate estimation
library(GenomicRanges)     # For genomic data manipulation
library(dplyr)
library(plyranges)
#library(tidyverse)         # For data manipulation and visualization
library(ggplot2)           # For plotting
library(pracma)
library(BiocFileCache)

## ----file-paths---------------------------------------------------------------

# Fetch the example data archive through a BiocFileCache cache and unpack it
# into this session's temporary directory.
bfc <- BiocFileCache(ask = FALSE)
zip_path <- bfcrpath(
  bfc,
  "https://zenodo.org/records/20618059/files/rate_estimation_vignette_data.zip"
)
unzip(zip_path, exdir = tempdir())

# Directory inside the unpacked archive that holds the files used below
data_dir <- file.path(tempdir(), "vignettes", "lrt_test_data")

# TSN file for the CD4 sample (RDS; read with readRDS() later on)
# Expected to hold a GRanges object with TSS annotations for the sample
# NOTE(review): confirm the intended expansion of "TSN" before spelling it
# out in the text
cd4_tsn_file <- file.path(data_dir, "PROseq-HUMAN-CD4_tsn.RDS")

# Transcript annotation file (RDS)
# Contains the gene/transcript models for all genes
human_tq_file <- file.path(data_dir, "HUMAN.RDS")

# PRO-seq signal files, one bigWig per strand
cd4_bw_plus <- file.path(data_dir, "PROseq-HUMAN-CD4_plus.bw")
cd4_bw_minus <- file.path(data_dir, "PROseq-HUMAN-CD4_minus.bw")

# Fail early, with the offending paths, if the archive did not unpack to the
# expected layout -- otherwise the problem only shows up later as a less
# informative error when the files are read.
required_files <- c(cd4_tsn_file, human_tq_file, cd4_bw_plus, cd4_bw_minus)
missing_files <- required_files[!file.exists(required_files)]
if (length(missing_files) > 0) {
  stop("Expected data file(s) not found after unzipping: ",
       paste(missing_files, collapse = ", "), call. = FALSE)
}


## ----load-tss-----------------------------------------------------------------
# Read the TSN annotations for the CD4 sample from the RDS file
cd4_tsn <- readRDS(file = cd4_tsn_file)

# Preview the first few records of the loaded object
head(x = cd4_tsn)

## ----select-single-tss--------------------------------------------------------
# Filter the CD4 TSN set with STADyUM's internal keep_max_tsn() helper.
# presumably a single TSS is retained per gene (the helper name suggests a
# "max" criterion rather than "most upstream") -- TODO confirm
cd4_tsn <- STADyUM:::keep_max_tsn(cd4_tsn)

# Report how many sites remain after filtering
cat("CD4 TSS sites:", length(cd4_tsn), "\n", sep = " ")

## ----load-transcripts---------------------------------------------------------
# Read the transcript model object covering all genes
human_tq <- readRDS(file = human_tq_file)

# Group the transcript ranges by Ensembl gene id, merge overlapping ranges
# within each gene in a strand-aware way, then sort the result.
# NOTE(review): a gene whose ranges do not overlap would keep more than one
# merged range, so length(gngrng) counts ranges -- confirm it equals the
# number of genes for this annotation.
gngrng <- sort(
  plyranges::reduce_ranges_directed(
    group_by(human_tq@transcripts, ensembl_gene_id)
  )
)

cat("Number of genes:", length(gngrng), "\n", sep = " ")

## ----define-count-regions-----------------------------------------------------
# Collapse every gene range to a single position at its 3' end:
# anchoring at the 3' end keeps that end fixed while the width is set to 1,
# giving one point coordinate per range (used as the termination site).
bw_tts <- mutate(plyranges::anchor_3p(gngrng), width = 1)

# Derive the read-count regions for CD4 from the TSS set and the 3' points.
# The result is accessed below through its `pause` and `gene_body` elements.
cd4_count_region <- STADyUM:::build_readcount_regions(cd4_tsn, bw_tts)

cat("CD4 pause regions:", length(cd4_count_region$pause), "\n", sep = " ")
cat("CD4 gene body regions:", length(cd4_count_region$gene_body), "\n",
    sep = " ")

## ----run-stadyum--------------------------------------------------------------
# Run the STADyUM rate estimation for the CD4 sample.
# Arguments, in order: plus-strand bigWig path, minus-strand bigWig path,
# pause regions, gene-body regions, and the sample label.
cd4_rate <- estimateTranscriptionRates(
  cd4_bw_plus,
  cd4_bw_minus,
  cd4_count_region$pause,
  cd4_count_region$gene_body,
  "CD4"
)

cat("CD4 rate estimation complete\n")

## ----cd4-beta-chi, fig.width=3.5, fig.height=3--------------------------------

# Number of rows in the rate table returned by rates()
cat("CD4 genes with estimated rates:", nrow(rates(cd4_rate)), "\n")

# Scatter plot of the "betaAdp" column against the "chi" column, both on a
# log scale, with fixed axis limits.
# The viridis colour scale comes from ggplot2, so it is applied
# unconditionally; previously it was only added when the optional ggpubr
# package happened to be installed, which made the figure's colours depend
# on an unrelated package.
p1 <- plotScatterDensity(cd4_rate, "betaAdp", "chi",
                         log_x = TRUE, log_y = TRUE,
                         xlab = expression(log[10] ~ beta),
                         ylab = expression(log[10] ~ chi),
                         xlim = c(-6, -1), ylim = c(-4, 1)) +
  scale_color_viridis_c()

# Optional annotation: Spearman correlation label, only when ggpubr is
# available (it is not a hard dependency of this script).
if (requireNamespace("ggpubr", quietly = TRUE)) {
  p1 <- p1 +
    ggpubr::stat_cor(mapping = aes(label = after_stat(r.label)),
                     method = "spearman", label.x = -3.5, label.y = -3.8,
                     size = 5)
}

print(p1)

## ----cd4-beta-elongation, fig.width=8, fig.height=3---------------------------
# Scatter plot of the "betaAdp" column (log scale) against "fkMean".
# The viridis colour scale is a ggplot2 feature and is therefore added
# regardless of whether the optional ggpubr package is installed; only the
# correlation label below depends on ggpubr.
p1 <- plotScatterDensity(cd4_rate, "betaAdp", "fkMean",
                         log_x = TRUE,
                         xlab = expression(log[10] ~ beta),
                         ylab = expression(mu)) +
  scale_color_viridis_c()

# Optional annotation: Spearman correlation label
if (requireNamespace("ggpubr", quietly = TRUE)) {
  p1 <- p1 +
    ggpubr::stat_cor(mapping = aes(label = after_stat(r.label)),
                     method = "spearman", label.x = -2.2, label.y = 160,
                     size = 5)
}

# Scatter plot of "betaAdp" (log scale) against "fkSD", with points coloured
# by the "sdGroup" column using a fixed two-colour palette.
p2 <- plotScatterDensity(
  cd4_rate, "betaAdp", "fkSD",
  log_x = TRUE,
  xlab = expression(log[10] ~ beta),
  ylab = expression(sigma),
  color_var = "sdGroup",
  color_values = c(Broad = "#1b9e77", Sharp = "#762a83"),
  color_lab = expression(sigma ~ group)
)

# Optional Spearman correlation label; `group = 1` makes it a single label
# for all points instead of one per colour group.
if (requireNamespace("ggpubr", quietly = TRUE)) {
  p2 <- p2 + ggpubr::stat_cor(
    mapping = aes(label = after_stat(r.label), group = 1),
    method = "spearman",
    label.x = -2.5,
    label.y = 75,
    size = 3,
    color = "black"
  )
}

# Density histogram of fkSD, flipped so it can sit beside p2 as a marginal
# panel; axis titles, text and ticks are blanked out.
# NOTE(review): 38.5 is presumably the fkSD cut-off between the Broad and
# Sharp groups -- confirm against how sdGroup is defined.
p2_hist <- ggplot(data = rates(cd4_rate), mapping = aes(x = .data$fkSD)) +
  geom_histogram(mapping = aes(y = after_stat(density)), bins = 50) +
  geom_vline(
    xintercept = 38.5,
    linetype = "dashed",
    color = "black",
    linewidth = 0.6
  ) +
  coord_flip() +
  theme(
    axis.title = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank()
  )
print(p1)

# Draw p2 and its marginal histogram side by side on a fresh page:
# a 1 x 2 grid layout whose columns have relative widths 4 : 1.
grid::grid.newpage()
grid::pushViewport(
  grid::viewport(
    layout = grid::grid.layout(
      nrow = 1,
      ncol = 2,
      widths = grid::unit(c(4, 1), "null")
    )
  )
)
print(p2, vp = grid::viewport(layout.pos.row = 1, layout.pos.col = 1))
print(p2_hist, vp = grid::viewport(layout.pos.row = 1, layout.pos.col = 2))

# Scatter plot of "fkMean" against "fkSD".
# The viridis colour scale is provided by ggplot2, so it is applied
# unconditionally instead of only when the optional ggpubr package is
# installed; only the correlation label below needs ggpubr.
p3 <- plotScatterDensity(cd4_rate, "fkMean", "fkSD",
                         xlab = expression(mu),
                         ylab = expression(sigma)) +
  scale_color_viridis_c()

# Optional annotation: Spearman correlation label
if (requireNamespace("ggpubr", quietly = TRUE)) {
  p3 <- p3 +
    ggpubr::stat_cor(mapping = aes(label = after_stat(r.label)),
                     method = "spearman", label.x = 112, label.y = 1,
                     size = 5)
}

print(p3)


## ----session-info-------------------------------------------------------------
# Record the R version and loaded package versions for reproducibility
utils::sessionInfo()

