Skip to content

Commit cf9941c

Browse files
committed
Add update function for vector algebra
1 parent 3125010 commit cf9941c

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

Sources/Numerix/LinearAlgebra/Algebra.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,17 @@ public protocol Algebra {
2626
static func scale(_ a: inout Matrix<Self>, by k: Self)
2727
static func transpose(_ a: Matrix<Self>) -> Matrix<Self>
2828
static func swapValues(a: inout Matrix<Self>, b: inout Matrix<Self>)
29+
30+
static func update(a: inout Vector<Self>, k: Self, b: Vector<Self>)
2931
}
3032

3133
@_documentation(visibility: private)
3234
extension Int: Algebra {
3335

36+
public static func update(a: inout Vector<Int>, k: Int, b: Vector<Int>) {
37+
fatalError("Not supported for integer values")
38+
}
39+
3440
// Vector
3541

3642
public static func swapValues(a: inout Vector<Int>, b: inout Vector<Int>) {
@@ -127,6 +133,10 @@ extension Int: Algebra {
127133
@_documentation(visibility: private)
128134
extension Float: Algebra {
129135

136+
public static func update(a: inout Vector<Float>, k: Float, b: Vector<Float>) {
137+
cblas_saxpy(b.size, k, b.buffer.baseAddress, 1, a.buffer.baseAddress, 1)
138+
}
139+
130140
// Vector
131141

132142
public static func swapValues(a: inout Vector<Float>, b: inout Vector<Float>) {
@@ -225,6 +235,10 @@ extension Float: Algebra {
225235
@_documentation(visibility: private)
226236
extension Double: Algebra {
227237

238+
public static func update(a: inout Vector<Double>, k: Double, b: Vector<Double>) {
239+
cblas_daxpy(b.size, k, b.buffer.baseAddress, 1, a.buffer.baseAddress, 1)
240+
}
241+
228242
// Vector
229243

230244
public static func swapValues(a: inout Vector<Double>, b: inout Vector<Double>) {
@@ -322,6 +336,17 @@ extension Double: Algebra {
322336

323337
extension Vector where Scalar: Algebra {
324338

339+
/// Update the vector with element-wise addition of another vector multiplied by a scalar value.
340+
///
341+
/// This performs the Level 1 BLAS operation axpy which is represented by equation `y = y + αx` where `y` and `x`
342+
/// are vectors and `α` is a scalar value.
343+
/// - Parameters:
344+
/// - b: The vector `b` in `a = a + b * k`.
345+
/// - k: The scalar value `k` in `a = a + b * k`.
346+
public mutating func update(with b: Vector, times k: Scalar = 1) where Scalar: Numeric {
347+
Scalar.update(a: &self, k: k, b: b)
348+
}
349+
325350
/// Calculate the dot product of two vectors.
326351
///
327352
/// This calculates the dot product as `c = aᵀb` where `a` and `b` are vectors

Tests/VectorTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,14 @@ struct VectorTests {
252252
swapValues(&d, &e)
253253
#expect(d == [1, 2, 3, 4])
254254
#expect(e == [8, 9, 10, 11])
255+
256+
var y = Vector<Float>([2, 3, 4, 5, 6, 7])
257+
let x = Vector<Float>([1, 2, 3, 4, 5, 6])
258+
y.update(with: x)
259+
#expect(y == [3, 5, 7, 9, 11, 13])
260+
261+
y.update(with: x, times: 2)
262+
#expect(y == [5, 9, 13, 17, 21, 25])
255263
}
256264

257265
@Test func doubleAlgebra() {

0 commit comments

Comments
 (0)