Skip to content

Commit fb8839e

Browse files
Monthly moving avg (#363)
* added a monthly moving average function along with unit tests * cleaned up comments in monthly moving avg cpp code
1 parent 53fbd80 commit fb8839e

File tree

8 files changed

+508
-1
lines changed

8 files changed

+508
-1
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export(assignment_column)
1212
export(bbox)
1313
export(bin_var)
1414
export(bycatch)
15+
export(calc_monthly_moving_avg)
1516
export(catch_lm)
1617
export(category_cols)
1718
export(centroid_to_fsdb)

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ calculate_moving_avg <- function(unique_dates, unique_groups, obs_dates, obs_gro
55
.Call(`_FishSET_calculate_moving_avg`, unique_dates, unique_groups, obs_dates, obs_groups, obs_values, window_size, lag, year_lag, temporal, weighted)
66
}
77

8+
calculate_monthly_avg <- function(unique_years, unique_months, unique_groups, obs_years, obs_months, obs_groups, obs_values, window_size, month_lag, year_lag = 0L, weighted = FALSE) {
9+
.Call(`_FishSET_calculate_monthly_avg`, unique_years, unique_months, unique_groups, obs_years, obs_months, obs_groups, obs_values, window_size, month_lag, year_lag, weighted)
10+
}
11+
812
shift_sort_xcpp <- function(x, ch, y, distance, alts, ab) {
913
.Call(`_FishSET_shift_sort_xcpp`, x, ch, y, distance, alts, ab)
1014
}

R/calc_monthly_moving_avg.R

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#' Calculate monthly moving window averages
2+
#'
3+
#' @param df Primary data frame or data.table containing the observations.
4+
#' @param name String, the name of the new column to be added to \code{df}
5+
#' containing the calculated moving averages. Defaults to \code{"moving_avg"}.
6+
#' @param year_col String, the name of the column containing the year (integer).
7+
#' @param month_col String, the name of the column containing the month (integer).
8+
#' @param group_cols Character vector, variable(s) from \code{df} that define how to
9+
#' group the data for averaging (e.g. \code{c("VESSEL_ID", "ZONE")}). The moving
10+
#' average is calculated independently for each group defined here.
11+
#' @param value_col String, the name of the column containing the numeric values
12+
#' to be averaged.
13+
#' @param window_size Numeric, the size of the moving window in months. Defaults to 3.
14+
#' @param month_lag Numeric, the number of months to lag the window. Defaults to 1.
15+
#' For example, a lag of 1 starts the average from the previous month.
16+
#' @param year_lag Numeric, the number of years to lag the reference date. Defaults to 0.
17+
#' If set to 1, the function looks up the date from exactly one year prior before
18+
#' applying the \code{month_lag} and \code{window_size}.
19+
#' @param fill_empty_expectation Numeric or \code{NA}. Value used to fill \code{NA}
20+
#' values in the resulting moving average column. Defaults to \code{NA}.
21+
#'
22+
#' @importFrom data.table as.data.table := .SD setnames setorder
23+
#'
24+
#' @returns Returns the original \code{df} as a data.table with the additional
25+
#' column specified by \code{name}.
26+
#' @export
27+
28+
calc_monthly_moving_avg <- function(df,
29+
name,
30+
year_col,
31+
month_col,
32+
group_cols,
33+
value_col,
34+
window_size = 3,
35+
month_lag = 1,
36+
year_lag = 0,
37+
fill_empty_expectation = NA) {
38+
39+
dt <- as.data.table(df)
40+
41+
# Create unified group ID
42+
# Use a temp name to avoid clashes
43+
temp_id_col <- "TEMP_UNIFIED_ID_XYZ"
44+
dt[, (temp_id_col) := do.call(paste, c(.SD, sep = "_")), .SDcols = group_cols]
45+
46+
# Prepare unique vectors
47+
unique_times <- unique(dt[, .(get(year_col), get(month_col))])
48+
setnames(unique_times, c("yr", "mth"))
49+
setorder(unique_times, yr, mth)
50+
51+
unique_groups <- unique(dt[[temp_id_col]])
52+
53+
# Call C++ function
54+
res_mat <- calculate_monthly_avg(
55+
unique_years = unique_times$yr,
56+
unique_months = unique_times$mth,
57+
unique_groups = unique_groups,
58+
59+
obs_years = dt[[year_col]],
60+
obs_months = dt[[month_col]],
61+
obs_groups = dt[[temp_id_col]],
62+
obs_values = dt[[value_col]],
63+
64+
window_size = window_size,
65+
month_lag = month_lag,
66+
year_lag = year_lag,
67+
weighted = FALSE
68+
)
69+
70+
# Convert to data table
71+
n_times <- nrow(unique_times)
72+
n_groups <- length(unique_groups)
73+
74+
lookup_table <- data.table(
75+
yr = rep(unique_times$yr, times = n_groups),
76+
mth = rep(unique_times$mth, times = n_groups),
77+
group_id = rep(unique_groups, each = n_times),
78+
val_placeholder = as.vector(res_mat) # Flattens matrix column by column
79+
)
80+
81+
# Rename columns to match the input df and the desired target column name
82+
setnames(lookup_table,
83+
old = c("yr", "mth", "group_id", "val_placeholder"),
84+
new = c(year_col, month_col, temp_id_col, name))
85+
86+
dt <- merge(dt, lookup_table, by = c(year_col, month_col, temp_id_col), all.x = TRUE)
87+
88+
# Handle missing expectations
89+
if (!is.na(fill_empty_expectation)) {
90+
# Replace NAs in the NEW column with the provided value
91+
# We use get() to refer to the column by its string name
92+
dt[is.na(get(name)), (name) := fill_empty_expectation]
93+
}
94+
95+
# Cleanup
96+
dt[, (temp_id_col) := NULL]
97+
98+
return(dt)
99+
}

R/globals.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,7 @@ utils::globalVariables(c(
2626
'model_name',
2727
'policy_name',
2828
'rename',
29-
'V1'
29+
'V1',
30+
'yr',
31+
'mth'
3032
))

man/calc_monthly_moving_avg.Rd

Lines changed: 55 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/RcppExports.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,27 @@ BEGIN_RCPP
3030
return rcpp_result_gen;
3131
END_RCPP
3232
}
33+
// calculate_monthly_avg
34+
NumericMatrix calculate_monthly_avg(IntegerVector unique_years, IntegerVector unique_months, CharacterVector unique_groups, IntegerVector obs_years, IntegerVector obs_months, CharacterVector obs_groups, NumericVector obs_values, int window_size, int month_lag, int year_lag, bool weighted);
35+
RcppExport SEXP _FishSET_calculate_monthly_avg(SEXP unique_yearsSEXP, SEXP unique_monthsSEXP, SEXP unique_groupsSEXP, SEXP obs_yearsSEXP, SEXP obs_monthsSEXP, SEXP obs_groupsSEXP, SEXP obs_valuesSEXP, SEXP window_sizeSEXP, SEXP month_lagSEXP, SEXP year_lagSEXP, SEXP weightedSEXP) {
36+
BEGIN_RCPP
37+
Rcpp::RObject rcpp_result_gen;
38+
Rcpp::RNGScope rcpp_rngScope_gen;
39+
Rcpp::traits::input_parameter< IntegerVector >::type unique_years(unique_yearsSEXP);
40+
Rcpp::traits::input_parameter< IntegerVector >::type unique_months(unique_monthsSEXP);
41+
Rcpp::traits::input_parameter< CharacterVector >::type unique_groups(unique_groupsSEXP);
42+
Rcpp::traits::input_parameter< IntegerVector >::type obs_years(obs_yearsSEXP);
43+
Rcpp::traits::input_parameter< IntegerVector >::type obs_months(obs_monthsSEXP);
44+
Rcpp::traits::input_parameter< CharacterVector >::type obs_groups(obs_groupsSEXP);
45+
Rcpp::traits::input_parameter< NumericVector >::type obs_values(obs_valuesSEXP);
46+
Rcpp::traits::input_parameter< int >::type window_size(window_sizeSEXP);
47+
Rcpp::traits::input_parameter< int >::type month_lag(month_lagSEXP);
48+
Rcpp::traits::input_parameter< int >::type year_lag(year_lagSEXP);
49+
Rcpp::traits::input_parameter< bool >::type weighted(weightedSEXP);
50+
rcpp_result_gen = Rcpp::wrap(calculate_monthly_avg(unique_years, unique_months, unique_groups, obs_years, obs_months, obs_groups, obs_values, window_size, month_lag, year_lag, weighted));
51+
return rcpp_result_gen;
52+
END_RCPP
53+
}
3354
// shift_sort_xcpp
3455
NumericMatrix shift_sort_xcpp(NumericMatrix x, NumericMatrix ch, NumericVector y, NumericMatrix distance, int alts, int ab);
3556
RcppExport SEXP _FishSET_shift_sort_xcpp(SEXP xSEXP, SEXP chSEXP, SEXP ySEXP, SEXP distanceSEXP, SEXP altsSEXP, SEXP abSEXP) {
@@ -49,6 +70,7 @@ END_RCPP
4970

5071
static const R_CallMethodDef CallEntries[] = {
5172
{"_FishSET_calculate_moving_avg", (DL_FUNC) &_FishSET_calculate_moving_avg, 10},
73+
{"_FishSET_calculate_monthly_avg", (DL_FUNC) &_FishSET_calculate_monthly_avg, 11},
5274
{"_FishSET_shift_sort_xcpp", (DL_FUNC) &_FishSET_shift_sort_xcpp, 6},
5375
{NULL, NULL, 0}
5476
};

src/monthly_moving_avg.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#include <Rcpp.h>
2+
#include <map>
3+
#include <vector>
4+
#include <string>
5+
#include <algorithm>
6+
7+
using namespace Rcpp;
8+
9+
// Helper to convert Year/Month to an absolute month index (0-based from Year 0)
10+
// e.g., 2023, 1 -> 24276
11+
int get_abs_month(int year, int month) {
12+
return year * 12 + (month - 1);
13+
}
14+
15+
// [[Rcpp::export]]
16+
NumericMatrix calculate_monthly_avg(IntegerVector unique_years,
17+
IntegerVector unique_months,
18+
CharacterVector unique_groups,
19+
IntegerVector obs_years,
20+
IntegerVector obs_months,
21+
CharacterVector obs_groups,
22+
NumericVector obs_values,
23+
int window_size,
24+
int month_lag,
25+
int year_lag = 0,
26+
bool weighted = false) {
27+
28+
int n_times = unique_years.size();
29+
int n_groups = unique_groups.size();
30+
int n_obs = obs_years.size();
31+
32+
// --- Map Setup ---
33+
// Map "Absolute Month" -> Row Index
34+
std::map<int, int> time_to_row_map;
35+
for(int i = 0; i < n_times; ++i) {
36+
int abs_m = get_abs_month(unique_years[i], unique_months[i]);
37+
time_to_row_map[abs_m] = i;
38+
}
39+
40+
// Map "Group Name" -> Column Index
41+
std::map<String, int> group_to_col_map;
42+
for(int i = 0; i < n_groups; ++i) {
43+
group_to_col_map[unique_groups[i]] = i;
44+
}
45+
46+
// --- Data Aggregation (Handle multiple obs per month/group) ---
47+
// Initialize matrices
48+
NumericMatrix values_matrix(n_times, n_groups);
49+
std::fill(values_matrix.begin(), values_matrix.end(), 0.0);
50+
51+
IntegerMatrix count_matrix(n_times, n_groups);
52+
std::fill(count_matrix.begin(), count_matrix.end(), 0);
53+
54+
for(int i = 0; i < n_obs; ++i) {
55+
int abs_m = get_abs_month(obs_years[i], obs_months[i]);
56+
57+
// Safety check: ensure this time/group exists in our definition vectors
58+
if (time_to_row_map.find(abs_m) != time_to_row_map.end() &&
59+
group_to_col_map.find(obs_groups[i]) != group_to_col_map.end()) {
60+
61+
int row = time_to_row_map[abs_m];
62+
int col = group_to_col_map[obs_groups[i]];
63+
64+
if (!NumericVector::is_na(obs_values[i])) {
65+
values_matrix(row, col) += obs_values[i];
66+
count_matrix(row, col)++;
67+
}
68+
}
69+
}
70+
71+
// Calculate raw monthly means (normalize the aggregated sums)
72+
for (int r = 0; r < n_times; ++r) {
73+
for (int c = 0; c < n_groups; ++c) {
74+
if (count_matrix(r, c) > 0) {
75+
values_matrix(r, c) = values_matrix(r, c) / count_matrix(r, c);
76+
} else {
77+
values_matrix(r, c) = NA_REAL;
78+
}
79+
}
80+
}
81+
82+
// --- Sliding Window Calculation ---
83+
84+
NumericMatrix result_matrix(n_times, n_groups);
85+
std::fill(result_matrix.begin(), result_matrix.end(), NA_REAL);
86+
87+
for (int c = 0; c < n_groups; ++c) {
88+
for (int r = 0; r < n_times; ++r) {
89+
90+
// Identify the anchor time (Current Time - Year Lag)
91+
int current_abs = get_abs_month(unique_years[r], unique_months[r]);
92+
int anchor_abs = current_abs - (year_lag * 12);
93+
94+
double sum_product = 0.0;
95+
double sum_divisor = 0.0;
96+
97+
// Define Window (based on Month Lag)
98+
// Loop through the offsets for the window size
99+
for (int w = 0; w < window_size; ++w) {
100+
int lag_amount = month_lag + w;
101+
int target_abs = anchor_abs - lag_amount;
102+
103+
// Check if this specific target month exists in our data
104+
auto it = time_to_row_map.find(target_abs);
105+
106+
if (it != time_to_row_map.end()) {
107+
int target_row = it->second;
108+
double val = values_matrix(target_row, c);
109+
110+
if (!NumericVector::is_na(val)) {
111+
if (weighted) {
112+
int weight = count_matrix(target_row, c);
113+
sum_product += val * weight;
114+
sum_divisor += weight;
115+
} else {
116+
sum_product += val;
117+
sum_divisor += 1.0;
118+
}
119+
}
120+
}
121+
}
122+
123+
if (sum_divisor > 0) {
124+
result_matrix(r, c) = sum_product / sum_divisor;
125+
}
126+
}
127+
}
128+
129+
// Set Names
130+
colnames(result_matrix) = unique_groups;
131+
132+
return result_matrix;
133+
}

0 commit comments

Comments
 (0)