@@ -34,8 +34,8 @@ library(viridis)
3434
3535## What is a Generalized Additive Model (GAM)?
3636
37- - ** GAMs** are a class of statistical models used to model complex, non-linear relationships between the response and the predictors.
38- - ** Key idea** : GAMs model the mean of the response variable as a sum of smooth functions of the predictors.
37+ - ** GAMs** are a class of statistical models used to model complex, non-linear relationships between the response and the predictors
38+ - ** Key idea** : GAMs model the mean of the response variable as a sum of smooth functions of the predictors
3939
4040## GAM Formula
4141
@@ -79,7 +79,7 @@ library(viridis)
7979
8080Very powerful R package for fitting GAMs
8181
82- Univariate smooth terms are expressed as
82+ Univariate smooth terms are expressed with ` s() `
8383
8484``` {r echo = TRUE, eval = FALSE}
8585gam(y ~ s(x1, k = 10, bs = "cr") +
@@ -97,7 +97,7 @@ We'll cover more complicated smooths later
9797- ** Some Types of Splines** :
9898 - ** B-splines (Basis Splines)** : Commonly used due to computational efficiency and flexibility
9999 - ** Cubic Splines** : Splines with cubic polynomials between each pair of adjacent knots
100- - ** Thin-plate Splines** : A generalization of B-splines, used for higher-dimensional data
100+ - ** Thin-plate Splines** : A generalization of B-splines, often used for higher-dimensional data
101101
102102## What Are Knots in Splines?
103103
@@ -170,7 +170,110 @@ To understand how the weighted sum of basis functions creates a smooth function,
170170weights <- c(0.4, -0.5, 0.3, 0.1, 0.6, -0.2) # Example weights for each basis function
171171
172172# Calculate the weighted sum of basis functions
173- weighted_spline <- rowSums(spl * weights)
173+ weighted_spline <- spl %*% matrix(weights, ncol=1)
174+
175+ # data frame for plotting the basis functions and the weighted spline
176+ df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
177+ "b" = sort(rep(1:ncol(spl), nrow(spl))),
178+ "basis_function" = c(spl))
179+
180+ # Data frame for the weighted spline
181+ df_weighted_spline <- data.frame(x = x,
182+ weighted_spline = weighted_spline)
183+
184+ # Plot both individual basis functions and the weighted sum
185+ p3 <- ggplot() +
186+ geom_line(data = df_splines_weighted, aes(x, basis_function, group = b, col = as.factor(b)), size = 1.2) +
187+ geom_line(data = df_weighted_spline, aes(x, weighted_spline), col = "blue", size = 1.2) +
188+ theme_bw() +
189+ xlab("X") +
190+ ylab("Function Value") +
191+ scale_color_viridis_d(end = 0.8) +
192+ theme(legend.position = "none") +
193+ ggtitle("Basis Functions and Weighted Spline")
194+ print(p3)
195+ ```
196+
197+ ## What if we make the basis dimension smaller?
198+
199+ ` weights <- c(0.2, -0.1, 0.3) `
200+
201+ ``` {r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
202+
203+ # Define weights for the basis functions
204+ weights <- c(0.2, -0.1, 0.3) # Example weights for each basis function
205+ spl <- splines::bs(x, df = 3)
206+
207+ # Calculate the weighted sum of basis functions
208+ weighted_spline <- spl %*% matrix(weights, ncol=1)
209+
210+ # data frame for plotting the basis functions and the weighted spline
211+ df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
212+ "b" = sort(rep(1:ncol(spl), nrow(spl))),
213+ "basis_function" = c(spl))
214+
215+ # Data frame for the weighted spline
216+ df_weighted_spline <- data.frame(x = x,
217+ weighted_spline = weighted_spline)
218+
219+ # Plot both individual basis functions and the weighted sum
220+ p3 <- ggplot() +
221+ geom_line(data = df_splines_weighted, aes(x, basis_function, group = b, col = as.factor(b)), size = 1.2) +
222+ geom_line(data = df_weighted_spline, aes(x, weighted_spline), col = "blue", size = 1.2) +
223+ theme_bw() +
224+ xlab("X") +
225+ ylab("Function Value") +
226+ scale_color_viridis_d(end = 0.8) +
227+ theme(legend.position = "none") +
228+ ggtitle("Basis Functions and Weighted Spline")
229+ print(p3)
230+ ```
231+
232+ ## What if we make the basis dimension smaller?
233+
234+ ` weights <- c(0.6, -0.1, 0.01) `
235+
236+ ``` {r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
237+
238+ # Define weights for the basis functions
239+ weights <- c(0.6, -0.1, 0.01) # Example weights for each basis function
240+ spl <- splines::bs(x, df = 3)
241+
242+ # Calculate the weighted sum of basis functions
243+ weighted_spline <- spl %*% matrix(weights, ncol=1)
244+
245+ # data frame for plotting the basis functions and the weighted spline
246+ df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
247+ "b" = sort(rep(1:ncol(spl), nrow(spl))),
248+ "basis_function" = c(spl))
249+
250+ # Data frame for the weighted spline
251+ df_weighted_spline <- data.frame(x = x,
252+ weighted_spline = weighted_spline)
253+
254+ # Plot both individual basis functions and the weighted sum
255+ p3 <- ggplot() +
256+ geom_line(data = df_splines_weighted, aes(x, basis_function, group = b, col = as.factor(b)), size = 1.2) +
257+ geom_line(data = df_weighted_spline, aes(x, weighted_spline), col = "blue", size = 1.2) +
258+ theme_bw() +
259+ xlab("X") +
260+ ylab("Function Value") +
261+ scale_color_viridis_d(end = 0.8) +
262+ theme(legend.position = "none") +
263+ ggtitle("Basis Functions and Weighted Spline")
264+ print(p3)
265+ ```
266+
267+ ## What if we make the basis dimension larger?
268+
269+ ``` {r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
270+
271+ # Define weights for the basis functions
272+ weights <- c(0.4, -0.5, 0.3, 0.1, 0.6, -0.2, 0.4, -0.1, 0.3, -0.3, 0.2, 0.1) # Example weights for each basis function
273+ spl <- splines::bs(x, df = 12)
274+
275+ # Calculate the weighted sum of basis functions
276+ weighted_spline <- spl %*% matrix(weights, ncol=1)
174277
175278# data frame for plotting the basis functions and the weighted spline
176279df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
@@ -219,11 +322,12 @@ X_spline <- sc[[1]]$X
219322- Not exactly the same as mgcv because mgcv also uses penalty matrix
220323
221324``` {r eval=FALSE, echo=TRUE}
325+ # simulate sinusoidal data + obs error
222326y <- sin(2 * pi * x) + rnorm(100, sd = 0.2)
327+ # spline regression
223328fit <- lm(y ~ X_spline)
224329```
225330
226-
227331## Key Take Homes
228332
229333- ** It doesn't take many basis functions to create a flexible spline**
@@ -341,6 +445,12 @@ ggplot() +
341445 ggtitle("Knots Placement and Predictions")
342446```
343447
448+ ## When does knot placement matter?
449+
450+ - ** When the data is unevenly distributed** : If the data is concentrated in certain regions, more knots can be placed in high density regions
451+ - ** When the relationship is non-linear** : If the relationship between the predictor and response variable is highly non-linear, adding knots can help the spline fit better
452+ - ** When the response data has abrupt changes** : If the data is affected by regimes / change points / big changes or discontinuities
453+
344454## Common Smooth Types in GAMs
345455
346456- ** Thin Plate Splines (` tp ` )** : Flexible, non-linear smooths for complex, irregular data
@@ -364,6 +474,29 @@ ggplot() +
364474- ** Incorporating Trend and Noise**
365475 - Smooth functions can represent underlying trends, while residuals capture random noise or irregularities
366476
477+ ## GAMs and DLMs:
478+
479+ * Example: daily shad counts on the Columbia River (Bonneville)
480+
481+ ``` {r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
482+ shad <- read.csv("shad.csv")
483+ shad$year <- lubridate::year(shad$date)
484+
485+ shad22_24 <- dplyr::filter(shad, year > 2021)
486+ shad22_24$jday <- seq(1, nrow(shad22_24))
487+ shad22_24 |>
488+ ggplot(aes(date, log(shad))) + geom_point() +
489+ ggtitle("Shad counts 2022 - 2024") + theme_bw() + ylab("ln (Shad)")
490+ ```
491+
492+ ## GAMs and DLMs:
493+
494+ * Use smooth to fill in the gaps
495+ * ` cor(obs, pred) ` ~ 0.86
496+ ``` {r}
497+ fit <- gam(log(shad) ~ as.factor(year) + s(jday, bs = "cr"),
498+ data = shad22_24)
499+ ```
367500
368501## GAMs and DLMs:
369502
@@ -404,6 +537,7 @@ gridExtra::grid.arrange(p1, p2, ncol = 1)
404537
405538* What is this model doing? Where have we seen something similar??
406539``` {r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
540+ # personal savings = s(time)
407541fit <- gam(psavert ~ s(time_num, bs = "cr"),
408542 data = economics)
409543```
@@ -423,13 +557,16 @@ plot(fit)
423557* How about a smooth / random walk on the covariate. Is this correct and or why not?
424558
425559``` {r echo=TRUE, eval=TRUE}
560+ # savings rate ~ s(unemployment)
426561fit <- gam(psavert ~ s(ln_unemploy, bs = "cr"),
427562 data = economics)
428563```
429564
430565## GAMs and DLMs:
431566
432- * This is fitting a non-linear smooth of CUI (totally ignoring time)
567+ * This is fitting a flexible non-linear model
568+
569+ * However, non-linear smooth of unemployment totally ignores time aspect
433570
434571``` {r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
435572plot(fit)
@@ -442,17 +579,44 @@ plot(fit)
442579* The prior model was fitting a non-linear smooth of the covariate
443580
444581* What about a smooth of the covariate and time?
582+
583+ * Here we add a 2D smooth, ` s(ln_unemploy, time_num) `
445584``` {r echo=TRUE, eval=TRUE}
446585fit <- gam(psavert ~ s(ln_unemploy, time_num),
447586 data = economics)
448587```
449588
450589## GAMs and DLMs:
451590
452- * This is fitting 2D smooth of ln_unemploy and time
591+ * What is the 2D smooth of ln_unemploy and time doing?
592+
593+ * There's maybe some slight non-linearity here
453594
454595``` {r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
455- plot(fit)
596+ vis.gam(fit, view = c("time_num", "ln_unemploy"), plot.type = "contour", color = "heat")
597+ ```
598+
599+ ## GAMs and DLMs
600+
601+ * In some cases, the 2D smooth is more complciated
602+
603+ ``` {r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
604+
605+ # Simulate predictors
606+ set.seed(123)
607+ n <- 400
608+ x1 <- runif(n, 0, 10)
609+ x2 <- runif(n, 0, 10)
610+
611+ # Create a response with interaction
612+ # The response surface is non-additive: x1's effect depends on x2
613+ y <- sin(x1) * cos(x2) + rnorm(n, sd = 0.3)
614+
615+ dat <- data.frame(x1 = x1, x2 = x2, y = y)
616+
617+ model <- gam(y ~ s(x1, x2), data = dat)
618+
619+ vis.gam(model, view = c("x1", "x2"), plot.type = "contour", color = "topo")
456620```
457621
458622## GAMs and DLMs
@@ -465,15 +629,16 @@ plot(fit)
465629
466630* Another way to model the interaction is with the ` by ` covariate
467631
468- * This creates separate smooths for different levels of the ` by ` covariate
632+ * This lets the smooth vary / creates separate smooths for different values of the ` by ` covariate
469633
470634``` {r echo=TRUE, eval=FALSE}
471635s(predictor, by = covariate)
472636```
473637
474638## GAMs and DLMs:
475639
476- * For time series data, it's a common mistake to use
640+ * For time series data, a common mistake is to use
641+
477642``` {r echo=TRUE, eval=TRUE}
478643fit <- gam(psavert ~ s(ln_unemploy, by = time_num),
479644 data = economics)
@@ -507,6 +672,13 @@ fit <- gam(psavert ~ s(time_num, by = ln_unemploy),
507672
508673## GAMs and DLMs:
509674
675+ ``` {r echo = FALSE, eval = TRUE, fig.height=3, fig.width=6}
676+ df <- data.frame("Goal" = c("Effect of x changes over time", "Effect of time changes with x"), "Formula" = c("s(time, by = x)", "s(x, by = time)"))
677+ knitr::kable(df)
678+ ```
679+
680+ ## GAMs and DLMs:
681+
510682* Now we have a time-varying slope
511683``` {r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
512684plot(fit)
@@ -600,7 +772,7 @@ rbind(dat1[554:574,], dat2[554:574,]) |>
600772
601773* ` compare_gams ` Available in slide Rmd
602774
603- ``` {r echo=FALSE}
775+ ``` {r echo=FALSE, warning=FALSE, message=FALSE }
604776compare_gams <- function(gam1, gam2, data, response, model_names = c("Non-linear GAM", "Time-varying GAM")) {
605777 models <- list(gam1, gam2)
606778 results <- list()
@@ -642,7 +814,7 @@ compare_gams <- function(gam1, gam2, data, response, model_names = c("Non-linear
642814
643815## Non-linearity vs non-stationarity
644816
645- ``` {r echo=TRUE, message=TRUE , warning=TRUE , results='asis'}
817+ ``` {r echo=TRUE, message=FALSE , warning=FALSE , results='asis'}
646818k <- compare_gams(fit_nonlinear, fit_dlm, response="psavert", data = economics)
647819print(knitr::kable(k, digits = 3, caption = "Comparison of GAMs"))
648820```
@@ -655,8 +827,10 @@ print(knitr::kable(k, digits = 3, caption = "Comparison of GAMs"))
655827
656828``` {r echo = TRUE}
657829data("SalmonSurvCUI")
830+ # time varying intercept model
658831g1 <- gam(logit.s ~ s(year, k = 10),
659832 data = SalmonSurvCUI)
833+ # time varying slope model
660834g2 <- gam(logit.s ~ s(year, by = CUI.apr, k = 10),
661835 data = SalmonSurvCUI)
662836```
0 commit comments