|
| 1 | +/** |
| 2 | + * 为密集计算设计的高性能非负最小二乘求解器 |
| 3 | + */ |
| 4 | +class NNLSSolver { |
| 5 | + /** |
| 6 | + * @param {number} K 个数 |
| 7 | + * @param {number} M 维度 |
| 8 | + * @param {number} lambda 正则化参数 防止不稳定 |
| 9 | + */ |
| 10 | + constructor(K, M, lambda = 1e-4) { |
| 11 | + this.K = K; |
| 12 | + this.M = M; |
| 13 | + this.lambda = lambda; |
| 14 | + // 预分配内存 |
| 15 | + this.c = new Float32Array(K); // 最终系数 (K) |
| 16 | + this.s = new Float32Array(K); // 候选系数 (K) |
| 17 | + this.w = new Float32Array(K); // 梯度 (K) |
| 18 | + this.residual = new Float32Array(M); // 增量残差 (M) |
| 19 | + this.matM = new Float32Array(M * M); // 正规方程矩阵 (M*M) |
| 20 | + this.rhsM = new Float32Array(M); // 正规方程右侧向量 (M) |
| 21 | + this.L = new Float32Array(M * M); // Cholesky 分解矩阵 (M*M) |
| 22 | + this.z = new Float32Array(M); // 临时连续解向量 (M) |
| 23 | + this.isP = new Uint8Array(K); |
| 24 | + this.pIdx = new Int32Array(M); |
| 25 | + } |
| 26 | + |
| 27 | + /** |
| 28 | + * 求解非负最小二乘问题 min ||Ax - b||_2^2 s.t. x >= 0 |
| 29 | + * @param {Float32Array} A 每M个数为一组,一共K组 |
| 30 | + * @param {Float32Array} b 长M |
| 31 | + * @returns {Float32Array} 长K的非负系数向量x 是this.c的引用 |
| 32 | + */ |
| 33 | + solve(A, b) { |
| 34 | + const { K, M, c, s, w, residual, isP, pIdx } = this; |
| 35 | + c.fill(0); |
| 36 | + isP.fill(0); |
| 37 | + residual.set(b); |
| 38 | + let pCount = 0; |
| 39 | + const tol = 1e-7 * M; // 根据维度动态调整容差 |
| 40 | + for (let iter = 0, maxIter = K << 1; iter < maxIter; iter++) { |
| 41 | + // 1. 计算梯度 w = A^T * residual |
| 42 | + let maxW = -1, jMax = -1; |
| 43 | + for (let j = 0; j < K; j++) { |
| 44 | + if (isP[j]) continue; |
| 45 | + let dot = 0; |
| 46 | + const offset = j * M; |
| 47 | + for (let i = 0; i < M; i++) dot += A[offset + i] * residual[i]; |
| 48 | + w[j] = dot; |
| 49 | + if (dot > maxW) { maxW = dot; jMax = j; } |
| 50 | + } |
| 51 | + if (jMax === -1 || maxW < tol) break; |
| 52 | + isP[jMax] = 1; |
| 53 | + pIdx[pCount++] = jMax; |
| 54 | + while (pCount > 0) { |
| 55 | + // 求解子问题,结果暂存在 s 中 |
| 56 | + if (!this._solveSubProblem(A, b, pCount, pIdx, s)) { |
| 57 | + const last = pIdx[--pCount]; |
| 58 | + isP[last] = c[last] = 0; |
| 59 | + break; |
| 60 | + } |
| 61 | + let alpha = 2.0; |
| 62 | + let hasConstraintViolation = false; |
| 63 | + for (let i = 0; i < pCount; i++) { |
| 64 | + const idx = pIdx[i]; |
| 65 | + if (s[idx] <= 0) { |
| 66 | + const ratio = c[idx] / (c[idx] - s[idx] + 1e-15); |
| 67 | + if (ratio < alpha) { |
| 68 | + alpha = ratio; |
| 69 | + hasConstraintViolation = true; |
| 70 | + } |
| 71 | + } |
| 72 | + } |
| 73 | + if (!hasConstraintViolation) { |
| 74 | + // 无冲突:更新残差并接受新系数 |
| 75 | + this._updateResidual(A, c, s, pCount, pIdx); |
| 76 | + for (let i = 0; i < pCount; i++) c[pIdx[i]] = s[pIdx[i]]; |
| 77 | + break; |
| 78 | + } |
| 79 | + // 有冲突:按 alpha 步长靠近,并剔除归零的变量 |
| 80 | + for (let i = 0; i < pCount; i++) { |
| 81 | + const idx = pIdx[i]; |
| 82 | + c[idx] += alpha * (s[idx] - c[idx]); |
| 83 | + } |
| 84 | + for (let i = 0; i < pCount; i++) { |
| 85 | + const idx = pIdx[i]; |
| 86 | + if (c[idx] < 1e-9) { // 稍微放宽归零判定 |
| 87 | + c[idx] = 0; |
| 88 | + isP[idx] = 0; |
| 89 | + pIdx[i] = pIdx[--pCount]; |
| 90 | + i--; |
| 91 | + } |
| 92 | + } |
| 93 | + this._fullResidualUpdate(A, b, c, pCount, pIdx); |
| 94 | + } |
| 95 | + } return c; |
| 96 | + } |
| 97 | + |
| 98 | + _solveSubProblem(A, b, n, pIdx, s) { |
| 99 | + const { M, matM, rhsM, L, z, lambda } = this; |
| 100 | + // 1. 构建正规方程 |
| 101 | + for (let i = 0; i < n; i++) { |
| 102 | + const offI = pIdx[i] * M; |
| 103 | + let dotB = 0; |
| 104 | + for (let r = 0; r < M; r++) dotB += A[offI + r] * b[r]; |
| 105 | + rhsM[i] = dotB; |
| 106 | + for (let j = 0; j <= i; j++) { |
| 107 | + const offJ = pIdx[j] * M; |
| 108 | + let dotA = 0; |
| 109 | + for (let r = 0; r < M; r++) dotA += A[offI + r] * A[offJ + r]; |
| 110 | + if (i === j) dotA += lambda; |
| 111 | + matM[i * n + j] = dotA; |
| 112 | + } |
| 113 | + } |
| 114 | + // 2. Cholesky 分解 |
| 115 | + for (let i = 0; i < n; i++) { |
| 116 | + for (let j = 0; j <= i; j++) { |
| 117 | + let sum = matM[i * n + j]; |
| 118 | + for (let k = 0; k < j; k++) sum -= L[i * n + k] * L[j * n + k]; |
| 119 | + if (i === j) { |
| 120 | + if (sum <= 0) return false; |
| 121 | + L[i * n + j] = Math.sqrt(sum); |
| 122 | + } else { |
| 123 | + L[i * n + j] = sum / L[j * n + j]; |
| 124 | + } |
| 125 | + } |
| 126 | + } |
| 127 | + // 3. 前向替换 (L * y = rhsM -> 结果存入 z) |
| 128 | + for (let i = 0; i < n; i++) { |
| 129 | + let sum = rhsM[i]; |
| 130 | + for (let k = 0; k < i; k++) sum -= L[i * n + k] * z[k]; |
| 131 | + z[i] = sum / L[i * n + i]; |
| 132 | + } |
| 133 | + // 4. 后向替换 (L^T * x = z -> 结果存入 z) |
| 134 | + for (let i = n - 1; i >= 0; i--) { |
| 135 | + let sum = z[i]; |
| 136 | + for (let k = i + 1; k < n; k++) sum -= L[k * n + i] * z[k]; |
| 137 | + z[i] = sum / L[i * n + i]; |
| 138 | + } |
| 139 | + // 5. 映射回原始大向量 s |
| 140 | + s.fill(0); // 必须清零,因为s共享 |
| 141 | + for (let i = 0; i < n; i++) { |
| 142 | + s[pIdx[i]] = z[i]; |
| 143 | + } return true; |
| 144 | + } |
| 145 | + _updateResidual(A, oldC, newS, n, pIdx) { |
| 146 | + const { M, residual } = this; |
| 147 | + for (let i = 0; i < n; i++) { |
| 148 | + const idx = pIdx[i]; |
| 149 | + const delta = newS[idx] - oldC[idx]; |
| 150 | + if (Math.abs(delta) < 1e-14) continue; |
| 151 | + const offset = idx * M; |
| 152 | + for (let r = 0; r < M; r++) residual[r] -= A[offset + r] * delta; |
| 153 | + } |
| 154 | + } |
| 155 | + _fullResidualUpdate(A, b, c, n, pIdx) { |
| 156 | + const { M, residual } = this; |
| 157 | + residual.set(b); |
| 158 | + for (let i = 0; i < n; i++) { |
| 159 | + const idx = pIdx[i]; |
| 160 | + if (c[idx] === 0) continue; |
| 161 | + const offset = idx * M; |
| 162 | + for (let r = 0; r < M; r++) residual[r] -= A[offset + r] * c[idx]; |
| 163 | + } |
| 164 | + } |
| 165 | + // 在调用 solve() 之后可以使用此函数获取当前残差的 L2 范数 |
| 166 | + calcError() { |
| 167 | + let sum = 0; |
| 168 | + for (let i = 0; i < this.M; i++) sum += this.residual[i] ** 2; |
| 169 | + return Math.sqrt(sum); |
| 170 | + } |
| 171 | +} |
0 commit comments