## ----setup, include=FALSE-----------------------------------------------------
# Vignette-wide chunk options: merge source and output into a single block and
# prefix every printed output line with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

library(SummarizedExperiment)
library(SpatialExperiment)
library(SpatialArtifacts)
library(ggplot2)
library(patchwork)
library(dplyr)

## ----eval=FALSE---------------------------------------------------------------
# BiocManager::install("SpatialArtifacts", version="devel")

## ----eval=FALSE---------------------------------------------------------------
# install.packages("BiocManager")
# BiocManager::install("SpatialArtifacts")

## ----eval=FALSE---------------------------------------------------------------
# # Standard Visium (55µm hexagonal grid)
# spe <- detectEdgeArtifacts(spe, platform = "visium", ...)
# 
# # VisiumHD 16µm (square grid)
# spe <- detectEdgeArtifacts(spe, platform = "visiumhd", resolution = "16um", ...)
# 
# # VisiumHD 8µm (square grid)
# spe <- detectEdgeArtifacts(spe, platform = "visiumhd", resolution = "8um", ...)

## ----run_workflow, message=TRUE, warning=FALSE--------------------------------
# Load the example dataset for this vignette; it is used below as a
# SpatialExperiment (colData, assays, spatialCoords).
data(spe_vignette)
# Loaded data dimensions:
dim(spe_vignette)

# Store the counts assay as a dense base matrix.
# NOTE(review): presumably the shipped assay is sparse/delayed and a dense
# matrix is wanted downstream -- confirm against detectEdgeArtifacts().
assay(spe_vignette, "counts") <- as.matrix(assay(spe_vignette, "counts"))
# Rename the per-spot total-count column "sum" to "sum_umi" so it matches the
# `qc_metric` passed to detectEdgeArtifacts() below.
names(colData(spe_vignette))[names(colData(spe_vignette)) == "sum"] <- "sum_umi"

# Detection phase. Results are written to colData columns prefixed with
# `name`; e.g. `edge_artifact_edge` is tabulated right after this call.
spe_detected <- detectEdgeArtifacts(
  spe_vignette,
  platform = "visium", # IMPORTANT: Specify Standard Visium platform
  qc_metric = "sum_umi",
  samples = "sample_id",
  batch_var = "sample_id",
  mad_threshold = 3,
  edge_threshold = 0.75,
  name = "edge_artifact"
)

# === RESULTS ===
# Count of spots flagged (TRUE) / not flagged (FALSE) as edge by detection.
table(Edge_Detected = spe_detected$edge_artifact_edge)

# Classification with Standard Visium parameters
# (`min_spots` is the size cut-off used by the classifier; the HD examples
# below scale it up for smaller bins).
spe_classified <- classifyEdgeArtifacts(
  spe_detected,
  min_spots = 20,
  name = "edge_artifact"
)

# === Classification Results ===
# Counts per final class stored in `edge_artifact_classification`.
table(spe_classified$edge_artifact_classification)

## ----visiumhd_16um_example, eval=FALSE----------------------------------------
# # This is a pseudo-example demonstrating VisiumHD 16µm workflow
# # Assumes you have loaded a VisiumHD SpatialExperiment object as 'spe_hd16'
# 
# # Step 1: Ensure required QC metrics are calculated
# library(scuttle)
# spe_hd16 <- addPerCellQCMetrics(spe_hd16)
# 
# # Step 2: Detection Phase - VisiumHD uses square grid (no 'shifted' needed)
# spe_hd16_detected <- detectEdgeArtifacts(
#   spe_hd16,
#   platform = "visiumhd", # Specify VisiumHD platform
#   resolution = "16um", # REQUIRED for VisiumHD
#   qc_metric = "sum_umi", # or "sum" depending on your colData
#   samples = "sample_id",
#   buffer_width_um = 100, # VisiumHD-specific parameter
#   mad_threshold = 2.5,
#   edge_threshold = 0.75,
#   name = "edge_artifact"
# )
# 
# # Step 3: Classification Phase - CRITICAL: Scale min_spots for VisiumHD resolution
# # For 16µm bins, scale a Standard Visium threshold (30 spots here) by the squared ratio of spot diameter to bin width, (55 / 16)^2 ≈ 11.8×
# min_spots_16um <- 30 * (55 / 16)^2 # ≈ 354 bins
# 
# spe_hd16_classified <- classifyEdgeArtifacts(
#   spe_hd16_detected,
#   qc_metric = "sum_umi",
#   min_spots = round(min_spots_16um), # ~350 bins
#   name = "edge_artifact"
# )
# 
# # Visualization (same approach as Standard Visium)
# table(spe_hd16_classified$edge_artifact_classification)

## ----visiumhd_8um_example, eval=FALSE-----------------------------------------
# # This is a pseudo-example demonstrating VisiumHD 8µm workflow
# # Assumes you have loaded a VisiumHD 8µm SpatialExperiment object as 'spe_hd8'
# 
# # Step 1: QC metrics (addPerCellQCMetrics() comes from scuttle; run library(scuttle) first if it is not attached)
# spe_hd8 <- addPerCellQCMetrics(spe_hd8)
# 
# # Step 2: Detection Phase
# spe_hd8_detected <- detectEdgeArtifacts(
#   spe_hd8,
#   platform = "visiumhd", # Specify VisiumHD platform
#   resolution = "8um", # REQUIRED: Specify 8µm resolution
# #   qc_metric = "sum_umi", # addPerCellQCMetrics() names this column "sum": use "sum" or rename it first
#   samples = "sample_id",
#   buffer_width_um = 100, # Buffer zone in micrometers
#   mad_threshold = 2.5,
#   edge_threshold = 0.75,
#   name = "edge_artifact"
# )
# 
# # Step 3: Classification with 8µm-appropriate threshold
# # For 8µm bins, scale a Standard Visium threshold (30 spots here) by the squared ratio of spot diameter to bin width, (55 / 8)^2 ≈ 47×
# min_spots_8um <- 30 * (55 / 8)^2 # ≈ 1,420 bins
# 
# spe_hd8_classified <- classifyEdgeArtifacts(
#   spe_hd8_detected,
#   qc_metric = "sum_umi",
#   min_spots = round(min_spots_8um), # ~1,400 bins
#   name = "edge_artifact"
# )
# 
# table(spe_hd8_classified$edge_artifact_classification)

## ----plot, fig.width=12, fig.height=16----------------------------------------
library(SpatialExperiment)
library(patchwork)

# Flatten per-spot metadata and spatial coordinates into one data frame and
# keep only the spots marked as in tissue for the diagnostic panels.
plot_data <- cbind(
  as.data.frame(colData(spe_classified)),
  as.data.frame(spatialCoords(spe_classified))
)
plot_data_in_tissue <- plot_data[plot_data$in_tissue, ]

# Shared look for every spatial panel: bare canvas, centred 12pt titles,
# legend on the right.
base_theme <- theme_void() +
  theme(
    plot.title = element_text(size = 12, hjust = 0.5),
    legend.position = "right"
  )

# Spatial scatter helper: one point per spot at its full-resolution pixel
# position, coloured by `fun()` applied to the column named `col`
# (`fun` defaults to the identity).
.plt <- function(df, col, fun = \(.) .) {
  spot_aes <- aes(
    x = pxl_col_in_fullres,
    y = pxl_row_in_fullres,
    col = fun(.data[[col]])
  )
  ggplot(df, spot_aes) +
    geom_point(size = 0.5) +
    base_theme +
    coord_fixed()
}

# Derive the per-spot annotations shown in panels C-E:
#   raw_problem     - TRUE when detection assigned a problem-area id
#   cluster_display - that id (NA when the spot has none)
#   artifact_type   - "Edge Artifact" for spots flagged as edge,
#                     "Interior Problem" for spots with a problem id that are
#                     not edge spots, otherwise "Normal" (an NA edge flag
#                     also falls through to "Normal")
plot_data_in_tissue <- within(plot_data_in_tissue, {
  raw_problem <- !is.na(edge_artifact_problem_id)
  cluster_display <- edge_artifact_problem_id
  artifact_type <- rep("Normal", length(edge_artifact_edge))
  artifact_type[edge_artifact_edge %in% TRUE] <- "Edge Artifact"
  artifact_type[raw_problem & edge_artifact_edge %in% FALSE] <- "Interior Problem"
})

p1 <- .plt(plot_data_in_tissue, "sum_umi", \(.) log10(. + 1)) +
  scale_color_viridis_c(name = "log10(UMI)") +
  labs(title = "A. UMI Counts")

p2 <- .plt(plot_data_in_tissue, "detected") +
  scale_color_viridis_c(name = "Genes", option = "plasma") +
  labs(title = "B. Detected Genes")

p3 <- .plt(plot_data_in_tissue, "raw_problem") +
  scale_color_manual(
    values = c("FALSE" = "lightgray", "TRUE" = "red"),
    name = "Problem?"
  ) +
  labs(title = "C. Raw Detection")

p4 <- .plt(plot_data_in_tissue, "cluster_display") +
  scale_color_discrete(name = "Cluster ID", na.value = "lightgray") +
  labs(title = "D. Problem Area Clusters") +
  theme(
    legend.key.size = unit(0.3, "cm"),
    legend.text = element_text(size = 8)
  )

p5 <- .plt(plot_data_in_tissue, "artifact_type") +
  scale_color_manual(
    values = c(
      "Normal" = "lightgray",
      "Edge Artifact" = "red",
      "Interior Problem" = "blue"
    ),
    name = "Type"
  ) +
  labs(title = "E. Edge vs Interior")

p6 <- .plt(plot_data_in_tissue, "edge_artifact_classification") +
  scale_color_manual(
    values = c(
      "not_artifact" = "lightgray",
      "large_edge_artifact" = "red",
      "small_edge_artifact" = "orange",
      "large_interior_artifact" = "blue",
      "small_interior_artifact" = "cyan"
    ),
    name = "Final Class",
    na.value = "grey50"
  ) +
  labs(title = "F. Hierarchical Classification")

# 3 x 2 patchwork grid of panels A-F.
(p1 | p2) / (p3 | p4) / (p5 | p6)

## ----classification, echo=FALSE-----------------------------------------------
# Breakdown of the final hierarchical classes from `spe_classified`.
# table() skips NA labels, so percentages are relative to classified spots.
final_summary <- table(spe_classified$edge_artifact_classification)
final_pct <- round(final_summary * 100 / sum(final_summary), 2)

final_df <- data.frame(
  Classification = names(final_summary),
  Count = as.numeric(final_summary),
  Percentage = paste0(as.numeric(final_pct), "%")
)

knitr::kable(final_df, caption = "Classification Breakdown")

## ----raw_edge, echo=FALSE-----------------------------------------------------
# Breakdown of the raw edge flag (`edge_artifact_edge`) from detection.
# table() skips NA flags, so percentages are relative to flagged/unflagged spots.
edge_summary <- table(spe_classified$edge_artifact_edge)
edge_pct <- round(edge_summary * 100 / sum(edge_summary), 2)

edge_df <- data.frame(
  Flagged_As_Edge = names(edge_summary),
  Count = as.numeric(edge_summary),
  Percentage = paste0(as.numeric(edge_pct), "%")
)

knitr::kable(edge_df, caption = "Raw Detection Breakdown")

## ----validation, echo=FALSE, warning=FALSE------------------------------------
# QC sanity check: compare median UMI and detected-gene counts between
# in-tissue spots flagged as raw edge artifacts and the remaining in-tissue
# spots.
in_tissue_data <- spe_classified[, spe_classified$in_tissue]
qc_data <- data.frame(
  sum_umi = in_tissue_data$sum_umi,
  detected_genes = in_tissue_data$detected,
  flagged = in_tissue_data$edge_artifact_edge
)

# Group medians. Indexing with an NA flag yields NA elements, which
# na.rm = TRUE then discards.
flagged_umi <- with(qc_data, median(sum_umi[flagged], na.rm = TRUE))
nonflagged_umi <- with(qc_data, median(sum_umi[!flagged], na.rm = TRUE))
flagged_gene <- with(qc_data, median(detected_genes[flagged], na.rm = TRUE))
nonflagged_gene <- with(qc_data, median(detected_genes[!flagged], na.rm = TRUE))

qc_summary <- data.frame(
  Metric = c("Median UMI", "Median Detected Genes"),
  `Flagged (Edge)` = round(c(flagged_umi, flagged_gene)),
  `Non-flagged` = round(c(nonflagged_umi, nonflagged_gene)),
  Difference = round(
    c(nonflagged_umi, nonflagged_gene) - c(flagged_umi, flagged_gene)
  ),
  check.names = FALSE
)

knitr::kable(qc_summary, caption = "QC Validation: Flagged vs Non-flagged Spots")

# Boxplot of log10 UMI per flag group (an NA flag gives an NA group label).
qc_data$flag_status <- ifelse(qc_data$flagged, "Flagged (Raw Edge)", "Non-Flagged")

validation_plot <- ggplot(
  qc_data,
  aes(x = flag_status, y = log10(sum_umi + 1), fill = flag_status)
) +
  geom_boxplot() +
  scale_fill_manual(
    values = c("Flagged (Raw Edge)" = "lightcoral", "Non-Flagged" = "lightblue")
  ) +
  labs(
    title = "QC Validation: UMI Counts in Raw Edge vs Non-Edge Spots",
    x = "Raw Edge Detection Status",
    y = "log10(UMI Count + 1)"
  ) +
  theme_minimal()

print(validation_plot)

## ----filtering----------------------------------------------------------------
# Remove spots classified as edge artifacts (large or small). Interior
# artifacts and unclassified (NA) spots are kept, because `%in%` never
# returns NA.
if ("edge_artifact_classification" %in% names(colData(spe_classified))) {
  is_edge_artifact <- spe_classified$edge_artifact_classification %in%
    c("large_edge_artifact", "small_edge_artifact")
  spots_to_keep <- !is_edge_artifact
  spe_filtered <- spe_classified[, spots_to_keep]
  message("Original number of spots: ", ncol(spe_classified))
  message("Number of spots after filtering: ", ncol(spe_filtered))
} else {
  # Always define `spe_filtered` (here as the unfiltered object) so later
  # chunks that use it still run instead of failing on a missing object.
  spe_filtered <- spe_classified
  message("Classification column not found. Filtering step skipped.")
}

## ----plotFiltering, fig.width=12, fig.height=6, message=FALSE, warning=FALSE----
# Before/after comparison of the edge-artifact filtering step. Both sides are
# restricted to in-tissue spots so the panels are directly comparable.

# Fixed colour per final class, shared by both classification panels.
classification_colors <- c(
  "not_artifact" = "lightgray",
  "large_edge_artifact" = "red",
  "small_edge_artifact" = "orange",
  "large_interior_artifact" = "blue",
  "small_interior_artifact" = "cyan"
)

plot_data_before <- as.data.frame(colData(spe_classified))
plot_data_before <- cbind(plot_data_before, as.data.frame(spatialCoords(spe_classified)))
plot_data_before_in_tissue <- plot_data_before[plot_data_before$in_tissue, ]

plot_data_after <- as.data.frame(colData(spe_filtered))
if (ncol(spe_filtered) > 0) {
  plot_data_after <- cbind(plot_data_after, as.data.frame(spatialCoords(spe_filtered)))
}
# Apply the same in-tissue restriction as the "before" panels (previously
# out-of-tissue spots appeared only on the "after" side).
plot_data_after_in_tissue <- plot_data_after[plot_data_after$in_tissue, ]

# `drop = FALSE` only keeps unused *factor* levels; `limits` pins every class
# into the legend even when the column is character, so removed edge classes
# still appear in the "after" legend and the two legends line up.
p1_umi_before <- .plt(plot_data_before_in_tissue, "sum_umi", \(.) log10(.+1)) +
  scale_color_viridis_c(name = "log10(UMI)") + ggtitle("UMI Counts (Before Filtering)")

p3_class_before <- .plt(plot_data_before_in_tissue, "edge_artifact_classification") +
  scale_color_manual(
    values = classification_colors,
    limits = names(classification_colors),
    name = "Final Class", na.value = "grey50", drop = FALSE
  ) +
  ggtitle("Classification (Before Filtering)")

if (nrow(plot_data_after_in_tissue) > 0) {
  p2_umi_after <- .plt(plot_data_after_in_tissue, "sum_umi", \(.) log10(.+1)) +
    scale_color_viridis_c(name = "log10(UMI)") + ggtitle("UMI Counts (After Filtering)")

  p4_class_after <- .plt(plot_data_after_in_tissue, "edge_artifact_classification") +
    scale_color_manual(
      values = classification_colors,
      limits = names(classification_colors),
      name = "Final Class", na.value = "grey50", drop = FALSE
    ) +
    ggtitle("Classification (After Filtering)")
} else {
  # Nothing left to draw: show titled placeholders instead of empty scatters.
  p2_umi_after <- ggplot() + theme_void() + ggtitle("UMI Counts (After Filtering - No Spots)")
  p4_class_after <- ggplot() + theme_void() + ggtitle("Classification (After Filtering - No Spots)")
}

combined_filtering_plot_2x2 <- (p1_umi_before | p2_umi_after) / (p3_class_before | p4_class_after)
print(combined_filtering_plot_2x2)

## ----session_info-------------------------------------------------------------
sessionInfo()

