Skip to content

Commit c6f2837

Browse files
committed
update GAM stuff for Tuesday
1 parent b92ea53 commit c6f2837

File tree

4 files changed

+9637
-61
lines changed

4 files changed

+9637
-61
lines changed

docs/Lectures/Week 6/lec_11_intro_to_GAMs.Rmd

Lines changed: 187 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ library(viridis)
3434

3535
## What is a Generalized Additive Model (GAM)?
3636

37-
- **GAMs** are a class of statistical models used to model complex, non-linear relationships between the response and the predictors.
38-
- **Key idea**: GAMs model the mean of the response variable as a sum of smooth functions of the predictors.
37+
- **GAMs** are a class of statistical models used to model complex, non-linear relationships between the response and the predictors
38+
- **Key idea**: GAMs model the mean of the response variable as a sum of smooth functions of the predictors
3939

4040
## GAM Formula
4141

@@ -79,7 +79,7 @@ library(viridis)
7979

8080
Very powerful R package for fitting GAMs
8181

82-
Univariate smooth terms are expressed as
82+
Univariate smooth terms are expressed with `s()`
8383

8484
```{r echo = TRUE, eval = FALSE}
8585
gam(y ~ s(x1, k = 10, bs = "cr") +
@@ -97,7 +97,7 @@ We'll cover more complicated smooths later
9797
- **Some Types of Splines**:
9898
- **B-splines (Basis Splines)**: Commonly used due to computational efficiency and flexibility
9999
- **Cubic Splines**: Splines with cubic polynomials between each pair of adjacent knots
100-
- **Thin-plate Splines**: A generalization of B-splines, used for higher-dimensional data
100+
- **Thin-plate Splines**: A generalization of B-splines, often used for higher-dimensional data
101101

102102
## What Are Knots in Splines?
103103

@@ -170,7 +170,110 @@ To understand how the weighted sum of basis functions creates a smooth function,
170170
weights <- c(0.4, -0.5, 0.3, 0.1, 0.6, -0.2) # Example weights for each basis function
171171
172172
# Calculate the weighted sum of basis functions
173-
weighted_spline <- rowSums(spl * weights)
173+
weighted_spline <- spl %*% matrix(weights, ncol=1)
174+
175+
# data frame for plotting the basis functions and the weighted spline
176+
df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
177+
"b" = sort(rep(1:ncol(spl), nrow(spl))),
178+
"basis_function" = c(spl))
179+
180+
# Data frame for the weighted spline
181+
df_weighted_spline <- data.frame(x = x,
182+
weighted_spline = weighted_spline)
183+
184+
# Plot both individual basis functions and the weighted sum
185+
p3 <- ggplot() +
186+
geom_line(data = df_splines_weighted, aes(x, basis_function, group = b, col = as.factor(b)), size = 1.2) +
187+
geom_line(data = df_weighted_spline, aes(x, weighted_spline), col = "blue", size = 1.2) +
188+
theme_bw() +
189+
xlab("X") +
190+
ylab("Function Value") +
191+
scale_color_viridis_d(end = 0.8) +
192+
theme(legend.position = "none") +
193+
ggtitle("Basis Functions and Weighted Spline")
194+
print(p3)
195+
```
196+
197+
## What if we make the basis dimension smaller?
198+
199+
`weights <- c(0.2, -0.1, 0.3)`
200+
201+
```{r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
202+
203+
# Define weights for the basis functions
204+
weights <- c(0.2, -0.1, 0.3) # Example weights for each basis function
205+
spl <- splines::bs(x, df = 3)
206+
207+
# Calculate the weighted sum of basis functions
208+
weighted_spline <- spl %*% matrix(weights, ncol=1)
209+
210+
# data frame for plotting the basis functions and the weighted spline
211+
df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
212+
"b" = sort(rep(1:ncol(spl), nrow(spl))),
213+
"basis_function" = c(spl))
214+
215+
# Data frame for the weighted spline
216+
df_weighted_spline <- data.frame(x = x,
217+
weighted_spline = weighted_spline)
218+
219+
# Plot both individual basis functions and the weighted sum
220+
p3 <- ggplot() +
221+
geom_line(data = df_splines_weighted, aes(x, basis_function, group = b, col = as.factor(b)), size = 1.2) +
222+
geom_line(data = df_weighted_spline, aes(x, weighted_spline), col = "blue", size = 1.2) +
223+
theme_bw() +
224+
xlab("X") +
225+
ylab("Function Value") +
226+
scale_color_viridis_d(end = 0.8) +
227+
theme(legend.position = "none") +
228+
ggtitle("Basis Functions and Weighted Spline")
229+
print(p3)
230+
```
231+
232+
## What if we make the basis dimension smaller?
233+
234+
`weights <- c(0.6, -0.1, 0.01)`
235+
236+
```{r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
237+
238+
# Define weights for the basis functions
239+
weights <- c(0.6, -0.1, 0.01) # Example weights for each basis function
240+
spl <- splines::bs(x, df = 3)
241+
242+
# Calculate the weighted sum of basis functions
243+
weighted_spline <- spl %*% matrix(weights, ncol=1)
244+
245+
# data frame for plotting the basis functions and the weighted spline
246+
df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
247+
"b" = sort(rep(1:ncol(spl), nrow(spl))),
248+
"basis_function" = c(spl))
249+
250+
# Data frame for the weighted spline
251+
df_weighted_spline <- data.frame(x = x,
252+
weighted_spline = weighted_spline)
253+
254+
# Plot both individual basis functions and the weighted sum
255+
p3 <- ggplot() +
256+
geom_line(data = df_splines_weighted, aes(x, basis_function, group = b, col = as.factor(b)), size = 1.2) +
257+
geom_line(data = df_weighted_spline, aes(x, weighted_spline), col = "blue", size = 1.2) +
258+
theme_bw() +
259+
xlab("X") +
260+
ylab("Function Value") +
261+
scale_color_viridis_d(end = 0.8) +
262+
theme(legend.position = "none") +
263+
ggtitle("Basis Functions and Weighted Spline")
264+
print(p3)
265+
```
266+
267+
## What if we make the basis dimension larger?
268+
269+
```{r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
270+
271+
# Define weights for the basis functions
272+
weights <- c(0.4, -0.5, 0.3, 0.1, 0.6, -0.2, 0.4, -0.1, 0.3, -0.3, 0.2, 0.1) # Example weights for each basis function
273+
spl <- splines::bs(x, df = 12)
274+
275+
# Calculate the weighted sum of basis functions
276+
weighted_spline <- spl %*% matrix(weights, ncol=1)
174277
175278
# data frame for plotting the basis functions and the weighted spline
176279
df_splines_weighted <- data.frame(x = rep(x, ncol(spl)),
@@ -219,11 +322,12 @@ X_spline <- sc[[1]]$X
219322
- Not exactly the same as mgcv because mgcv also uses penalty matrix
220323

221324
```{r eval=FALSE, echo=TRUE}
325+
# simulate sinusoidal data + obs error
222326
y <- sin(2 * pi * x) + rnorm(100, sd = 0.2)
327+
# spline regression
223328
fit <- lm(y ~ X_spline)
224329
```
225330

226-
227331
## Key Take Homes
228332

229333
- **It doesn't take many basis functions to create a flexible spline**
@@ -341,6 +445,12 @@ ggplot() +
341445
ggtitle("Knots Placement and Predictions")
342446
```
343447

448+
## When does knot placement matter?
449+
450+
- **When the data is unevenly distributed**: If the data is concentrated in certain regions, more knots can be placed in high density regions
451+
- **When the relationship is non-linear**: If the relationship between the predictor and response variable is highly non-linear, adding knots can help the spline fit better
452+
- **When the response data has abrupt changes**: If the data is affected by regimes / change points / big changes or discontinuities
453+
344454
## Common Smooth Types in GAMs
345455

346456
- **Thin Plate Splines (`tp`)**: Flexible, non-linear smooths for complex, irregular data
@@ -364,6 +474,29 @@ ggplot() +
364474
- **Incorporating Trend and Noise**
365475
- Smooth functions can represent underlying trends, while residuals capture random noise or irregularities
366476

477+
## GAMs and DLMs:
478+
479+
* Example: daily shad counts on the Columbia River (Bonneville)
480+
481+
```{r echo=FALSE, warning=FALSE, message=FALSE, fig.height=4, fig.width=7}
482+
shad <- read.csv("shad.csv")
483+
shad$year <- lubridate::year(shad$date)
484+
485+
shad22_24 <- dplyr::filter(shad, year > 2021)
486+
shad22_24$jday <- seq(1, nrow(shad22_24))
487+
shad22_24 |>
488+
ggplot(aes(date, log(shad))) + geom_point() +
489+
ggtitle("Shad counts 2022 - 2024") + theme_bw() + ylab("ln (Shad)")
490+
```
491+
492+
## GAMs and DLMs:
493+
494+
* Use smooth to fill in the gaps
495+
* `cor(obs, pred)` ~ 0.86
496+
```{r}
497+
fit <- gam(log(shad) ~ as.factor(year) + s(jday, bs = "cr"),
498+
data = shad22_24)
499+
```
367500

368501
## GAMs and DLMs:
369502

@@ -404,6 +537,7 @@ gridExtra::grid.arrange(p1, p2, ncol = 1)
404537

405538
* What is this model doing? Where have we seen something similar??
406539
```{r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
540+
# personal savings = s(time)
407541
fit <- gam(psavert ~ s(time_num, bs = "cr"),
408542
data = economics)
409543
```
@@ -423,13 +557,16 @@ plot(fit)
423557
* How about a smooth / random walk on the covariate? Is this correct, and why or why not?
424558

425559
```{r echo=TRUE, eval=TRUE}
560+
# savings rate ~ s(unemployment)
426561
fit <- gam(psavert ~ s(ln_unemploy, bs = "cr"),
427562
data = economics)
428563
```
429564

430565
## GAMs and DLMs:
431566

432-
* This is fitting a non-linear smooth of CUI (totally ignoring time)
567+
* This is fitting a flexible non-linear model
568+
569+
* However, the non-linear smooth of unemployment totally ignores the time aspect
433570

434571
```{r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
435572
plot(fit)
@@ -442,17 +579,44 @@ plot(fit)
442579
* The prior model was fitting a non-linear smooth of the covariate
443580

444581
* What about a smooth of the covariate and time?
582+
583+
* Here we add a 2D smooth, `s(ln_unemploy, time_num)`
445584
```{r echo=TRUE, eval=TRUE}
446585
fit <- gam(psavert ~ s(ln_unemploy, time_num),
447586
data = economics)
448587
```
449588

450589
## GAMs and DLMs:
451590

452-
* This is fitting 2D smooth of ln_unemploy and time
591+
* What is the 2D smooth of ln_unemploy and time doing?
592+
593+
* There's maybe some slight non-linearity here
453594

454595
```{r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
455-
plot(fit)
596+
vis.gam(fit, view = c("time_num", "ln_unemploy"), plot.type = "contour", color = "heat")
597+
```
598+
599+
## GAMs and DLMs
600+
601+
* In some cases, the 2D smooth is more complicated
602+
603+
```{r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
604+
605+
# Simulate predictors
606+
set.seed(123)
607+
n <- 400
608+
x1 <- runif(n, 0, 10)
609+
x2 <- runif(n, 0, 10)
610+
611+
# Create a response with interaction
612+
# The response surface is non-additive: x1's effect depends on x2
613+
y <- sin(x1) * cos(x2) + rnorm(n, sd = 0.3)
614+
615+
dat <- data.frame(x1 = x1, x2 = x2, y = y)
616+
617+
model <- gam(y ~ s(x1, x2), data = dat)
618+
619+
vis.gam(model, view = c("x1", "x2"), plot.type = "contour", color = "topo")
456620
```
457621

458622
## GAMs and DLMs
@@ -465,15 +629,16 @@ plot(fit)
465629

466630
* Another way to model the interaction is with the `by` covariate
467631

468-
* This creates separate smooths for different levels of the `by` covariate
632+
* This lets the smooth vary / creates separate smooths for different values of the `by` covariate
469633

470634
```{r echo=TRUE, eval=FALSE}
471635
s(predictor, by = covariate)
472636
```
473637

474638
## GAMs and DLMs:
475639

476-
* For time series data, it's a common mistake to use
640+
* For time series data, a common mistake is to use
641+
477642
```{r echo=TRUE, eval=TRUE}
478643
fit <- gam(psavert ~ s(ln_unemploy, by = time_num),
479644
data = economics)
@@ -507,6 +672,13 @@ fit <- gam(psavert ~ s(time_num, by = ln_unemploy),
507672

508673
## GAMs and DLMs:
509674

675+
```{r echo = FALSE, eval = TRUE, fig.height=3, fig.width=6}
676+
df <- data.frame("Goal" = c("Effect of x changes over time", "Effect of time changes with x"), "Formula" = c("s(time, by = x)", "s(x, by = time)"))
677+
knitr::kable(df)
678+
```
679+
680+
## GAMs and DLMs:
681+
510682
* Now we have a time-varying slope
511683
```{r echo = TRUE, eval = TRUE, fig.height=3, fig.width=6}
512684
plot(fit)
@@ -600,7 +772,7 @@ rbind(dat1[554:574,], dat2[554:574,]) |>
600772

601773
* `compare_gams` Available in slide Rmd
602774

603-
```{r echo=FALSE}
775+
```{r echo=FALSE, warning=FALSE, message=FALSE}
604776
compare_gams <- function(gam1, gam2, data, response, model_names = c("Non-linear GAM", "Time-varying GAM")) {
605777
models <- list(gam1, gam2)
606778
results <- list()
@@ -642,7 +814,7 @@ compare_gams <- function(gam1, gam2, data, response, model_names = c("Non-linear
642814

643815
## Non-linearity vs non-stationarity
644816

645-
```{r echo=TRUE, message=TRUE, warning=TRUE, results='asis'}
817+
```{r echo=TRUE, message=FALSE, warning=FALSE, results='asis'}
646818
k <- compare_gams(fit_nonlinear, fit_dlm, response="psavert", data = economics)
647819
print(knitr::kable(k, digits = 3, caption = "Comparison of GAMs"))
648820
```
@@ -655,8 +827,10 @@ print(knitr::kable(k, digits = 3, caption = "Comparison of GAMs"))
655827

656828
```{r echo = TRUE}
657829
data("SalmonSurvCUI")
830+
# time varying intercept model
658831
g1 <- gam(logit.s ~ s(year, k = 10),
659832
data = SalmonSurvCUI)
833+
# time varying slope model
660834
g2 <- gam(logit.s ~ s(year, by = CUI.apr, k = 10),
661835
data = SalmonSurvCUI)
662836
```

docs/Lectures/Week 6/lec_11_intro_to_GAMs.html

Lines changed: 124 additions & 47 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)