-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathreverse.rs
More file actions
176 lines (156 loc) · 6.85 KB
/
reverse.rs
File metadata and controls
176 lines (156 loc) · 6.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
use std::collections::HashMap;
use crate::float::Float;
use crate::opcode::{self, OpCode, UNUSED};
impl<F: Float> super::BytecodeTape<F> {
    /// Reverse sweep with weighted seeds for multiple outputs.
    ///
    /// Computes `∑_i seeds[i] * ∂output_i/∂x` — a vector-Jacobian product — where the
    /// outputs are those returned by `all_output_indices`, in that order.
    ///
    /// Returns the gradient with respect to all inputs (length [`num_inputs`](Self::num_inputs)).
    ///
    /// # Panics
    ///
    /// Panics if `seeds.len()` differs from the number of outputs.
    pub fn reverse_seeded(&self, seeds: &[F]) -> Vec<F> {
        let out_indices = self.all_output_indices();
        assert_eq!(
            seeds.len(),
            out_indices.len(),
            "seeds length must match number of outputs"
        );
        let ni = self.num_inputs as usize;
        let adjoints = self.reverse_seeded_full(seeds, out_indices);
        adjoints[..ni].to_vec()
    }

    /// Core reverse sweep loop shared by all scalar reverse sweep variants.
    ///
    /// Expects `adjoints` to be pre-seeded by the caller (length = `num_variables`).
    /// Reads primal values from `values` (either `self.values` or an external buffer).
    /// When `forced_signs` is `Some`, uses forced partials at matching tape indices
    /// (not applied to `Powi` or `Custom` nodes).
    ///
    /// Tape entries are visited from last to first. Each non-leaf entry with a nonzero
    /// adjoint propagates `partial * adjoint` into its operands and then has its own
    /// adjoint reset to zero. `Input`/`Const` entries are never reset, so after the
    /// sweep they hold the accumulated adjoints.
    pub(super) fn reverse_sweep_core(
        &self,
        adjoints: &mut [F],
        values: &[F],
        forced_signs: Option<&HashMap<u32, i8>>,
    ) {
        let zero = F::zero();
        for i in (0..self.opcodes.len()).rev() {
            let adj = adjoints[i];
            // Nothing to propagate from a node that received no adjoint.
            if adj == zero {
                continue;
            }
            match self.opcodes[i] {
                // Leaves: keep their adjoint, it is what callers read back.
                OpCode::Input | OpCode::Const => continue,
                OpCode::Custom => {
                    adjoints[i] = zero;
                    // For custom ops the second argument slot holds the callback index;
                    // an optional second operand is stored in `custom_second_args`.
                    let [a_idx, cb_idx] = self.arg_indices[i];
                    let a = values[a_idx as usize];
                    let b_idx_opt = self.custom_second_args.get(&(i as u32)).copied();
                    let b = b_idx_opt.map_or(zero, |bi| values[bi as usize]);
                    let r = values[i];
                    let (da, db) = self.custom_ops[cb_idx as usize].partials(a, b, r);
                    adjoints[a_idx as usize] = adjoints[a_idx as usize] + da * adj;
                    if let Some(bi) = b_idx_opt {
                        adjoints[bi as usize] = adjoints[bi as usize] + db * adj;
                    }
                }
                OpCode::Powi => {
                    adjoints[i] = zero;
                    // For Powi the second argument slot is not a tape index: it encodes
                    // the integer exponent.
                    let [a_idx, b_idx] = self.arg_indices[i];
                    let exp = opcode::powi_exp_decode_raw(b_idx);
                    // a^0 is constant, so its derivative is zero everywhere. Evaluating
                    // `0 * a^-1` instead would give `0 * inf = NaN` at a == 0, so skip it.
                    if exp != 0 {
                        let a = values[a_idx as usize];
                        let n = F::from(exp).expect("powi exponent must be representable in F");
                        let da = n * a.powi(exp - 1);
                        adjoints[a_idx as usize] = adjoints[a_idx as usize] + da * adj;
                    }
                }
                op => {
                    adjoints[i] = zero;
                    let [a_idx, b_idx] = self.arg_indices[i];
                    let a = values[a_idx as usize];
                    // Unary ops mark their second operand slot with `UNUSED`.
                    let has_b = b_idx != UNUSED;
                    let b = if has_b { values[b_idx as usize] } else { zero };
                    let r = values[i];
                    let forced = forced_signs.and_then(|fs| fs.get(&(i as u32)));
                    let (da, db) = match forced {
                        Some(&sign) => opcode::forced_reverse_partials(op, a, b, r, sign),
                        None => opcode::reverse_partials(op, a, b, r),
                    };
                    adjoints[a_idx as usize] = adjoints[a_idx as usize] + da * adj;
                    if has_b {
                        adjoints[b_idx as usize] = adjoints[b_idx as usize] + db * adj;
                    }
                }
            }
        }
    }

    /// Reverse sweep: compute adjoints seeded with `1` at tape index `seed_index`.
    ///
    /// Reads primal values from `self.values`.
    ///
    /// Returns an adjoint vector of length `num_variables`. Entries for `Input`/`Const`
    /// nodes hold the accumulated adjoints. Entries for intermediate nodes that took
    /// part in the sweep are reset to zero by [`reverse_sweep_core`](Self::reverse_sweep_core).
    ///
    /// # Panics
    ///
    /// Panics if `seed_index >= num_variables`.
    #[must_use]
    pub fn reverse(&self, seed_index: u32) -> Vec<F> {
        let n = self.num_variables as usize;
        let mut adjoints = vec![F::zero(); n];
        adjoints[seed_index as usize] = F::one();
        self.reverse_sweep_core(&mut adjoints, &self.values, None);
        adjoints
    }

    /// Reverse sweep with forced branch choices at specified tape indices.
    ///
    /// Like [`reverse`](Self::reverse), but for every tape index present in
    /// `forced_signs` the partials come from `opcode::forced_reverse_partials`
    /// with the given sign instead of `opcode::reverse_partials`.
    ///
    /// # Panics
    ///
    /// Panics if `seed_index >= num_variables`.
    pub(super) fn reverse_with_forced_signs(
        &self,
        seed_index: u32,
        forced_signs: &HashMap<u32, i8>,
    ) -> Vec<F> {
        let n = self.num_variables as usize;
        let mut adjoints = vec![F::zero(); n];
        adjoints[seed_index as usize] = F::one();
        self.reverse_sweep_core(&mut adjoints, &self.values, Some(forced_signs));
        adjoints
    }

    /// Reverse sweep reading from an external values buffer.
    ///
    /// Like [`reverse`](Self::reverse) but reads primal values from `values`
    /// instead of `self.values`. Pair with [`forward_into`](Self::forward_into)
    /// for parallel evaluation.
    ///
    /// # Panics
    ///
    /// Panics if `values.len() != num_variables` or `seed_index >= num_variables`.
    pub fn reverse_from(&self, values: &[F], seed_index: u32) -> Vec<F> {
        let n = self.num_variables as usize;
        assert_eq!(values.len(), n, "values buffer has wrong length");
        let mut adjoints = vec![F::zero(); n];
        adjoints[seed_index as usize] = F::one();
        self.reverse_sweep_core(&mut adjoints, values, None);
        adjoints
    }

    /// Forward + reverse: compute the gradient at new inputs.
    ///
    /// Runs `forward(inputs)` and then a reverse sweep seeded at `output_index`.
    /// Returns only the input adjoints (indices `0..num_inputs`).
    pub fn gradient(&mut self, inputs: &[F]) -> Vec<F> {
        self.forward(inputs);
        let adjoints = self.reverse(self.output_index);
        adjoints[..self.num_inputs as usize].to_vec()
    }

    /// Like [`gradient`](Self::gradient) but reuses a caller-provided buffer
    /// for the full-length adjoint vector, avoiding that allocation on repeated calls.
    ///
    /// The returned gradient (length `num_inputs`) is still a freshly allocated `Vec`.
    /// On return, `adjoint_buf` holds the post-sweep adjoints (length `num_variables`).
    pub fn gradient_with_buf(&mut self, inputs: &[F], adjoint_buf: &mut Vec<F>) -> Vec<F> {
        self.forward(inputs);
        let n = self.num_variables as usize;
        // Reset every entry: a reused buffer may hold leaf adjoints from a previous call.
        adjoint_buf.clear();
        adjoint_buf.resize(n, F::zero());
        adjoint_buf[self.output_index as usize] = F::one();
        self.reverse_sweep_core(adjoint_buf, &self.values, None);
        adjoint_buf[..self.num_inputs as usize].to_vec()
    }

    /// Reverse sweep with weighted seeds, returning the full adjoint vector.
    ///
    /// Adds `seeds[k]` to the adjoint of tape index `out_indices[k]`. Duplicate indices
    /// accumulate. Seeds and indices are paired with `zip`, so if the slices differ in
    /// length the extra trailing entries of the longer one are ignored.
    /// [`reverse_seeded`](Self::reverse_seeded) asserts equal lengths before calling this.
    pub(super) fn reverse_seeded_full(&self, seeds: &[F], out_indices: &[u32]) -> Vec<F> {
        let n = self.num_variables as usize;
        let zero = F::zero();
        let mut adjoints = vec![zero; n];
        for (&out_idx, &weight) in out_indices.iter().zip(seeds.iter()) {
            // Zero weights contribute nothing; skipping them is just a shortcut.
            if weight == zero {
                continue;
            }
            adjoints[out_idx as usize] = adjoints[out_idx as usize] + weight;
        }
        self.reverse_sweep_core(&mut adjoints, &self.values, None);
        adjoints
    }

    // ── Batch evaluation ──

    /// Evaluate the gradient at multiple input points.
    ///
    /// Returns one gradient vector per input point, in the same order as `inputs`.
    /// Points are evaluated sequentially, and each call overwrites `self.values`.
    pub fn gradient_batch(&mut self, inputs: &[&[F]]) -> Vec<Vec<F>> {
        inputs.iter().map(|x| self.gradient(x)).collect()
    }
}