Skip to content

Commit b3acdcc

Browse files
committed
Refactor Arithmetic protocol and extensions into multiple files
1 parent a36aa65 commit b3acdcc

File tree

8 files changed

+1142
-1118
lines changed

8 files changed

+1142
-1118
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
Arithmetic protocol.
3+
*/
4+
5+
import Accelerate
6+
7+
infix operator .* : MultiplicationPrecedence
8+
infix operator : MultiplicationPrecedence
9+
10+
@_documentation(visibility: private)
11+
public protocol Arithmetic {
12+
// A + k
13+
static func add(_ a: Vector<Self>, _ k: Self) -> Vector<Self>
14+
static func add(_ a: Matrix<Self>, _ k: Self) -> Matrix<Self>
15+
static func add(_ a: ShapedArray<Self>, _ k: Self) -> ShapedArray<Self>
16+
17+
// A + B
18+
static func add(_ a: Vector<Self>, _ b: Vector<Self>) -> Vector<Self>
19+
static func add(_ a: Matrix<Self>, _ b: Matrix<Self>) -> Matrix<Self>
20+
static func add(_ a: ShapedArray<Self>, _ b: ShapedArray<Self>) -> ShapedArray<Self>
21+
22+
// k - A
23+
static func subtract(_ k: Self, _ a: Vector<Self>) -> Vector<Self>
24+
static func subtract(_ k: Self, _ a: Matrix<Self>) -> Matrix<Self>
25+
static func subtract(_ k: Self, _ a: ShapedArray<Self>) -> ShapedArray<Self>
26+
27+
// A - k
28+
static func subtract(_ a: Vector<Self>, _ k: Self) -> Vector<Self>
29+
static func subtract(_ a: Matrix<Self>, _ k: Self) -> Matrix<Self>
30+
static func subtract(_ a: ShapedArray<Self>, _ k: Self) -> ShapedArray<Self>
31+
32+
// A - B
33+
static func subtract(_ a: Vector<Self>, _ b: Vector<Self>) -> Vector<Self>
34+
static func subtract(_ a: Matrix<Self>, _ b: Matrix<Self>) -> Matrix<Self>
35+
static func subtract(_ a: ShapedArray<Self>, _ b: ShapedArray<Self>) -> ShapedArray<Self>
36+
37+
// A * k
38+
static func multiply(_ a: Vector<Self>, _ k: Self) -> Vector<Self>
39+
static func multiply(_ a: Matrix<Self>, _ k: Self) -> Matrix<Self>
40+
static func multiply(_ a: ShapedArray<Self>, _ k: Self) -> ShapedArray<Self>
41+
42+
// A * B
43+
static func multiply(_ a: Vector<Self>, _ b: Vector<Self>) -> Vector<Self>
44+
static func multiply(_ a: Matrix<Self>, _ b: Matrix<Self>) -> Matrix<Self>
45+
static func multiply(_ a: ShapedArray<Self>, _ b: ShapedArray<Self>) -> ShapedArray<Self>
46+
47+
// A x B
48+
static func matrixMultiply(_ a: Matrix<Self>, _ b: Matrix<Self>) -> Matrix<Self>
49+
50+
// k / A
51+
static func divide(_ k: Self, _ a: Vector<Self>) -> Vector<Self>
52+
static func divide(_ k: Self, _ a: Matrix<Self>) -> Matrix<Self>
53+
static func divide(_ k: Self, _ a: ShapedArray<Self>) -> ShapedArray<Self>
54+
55+
// A / k
56+
static func divide(_ a: Vector<Self>, _ k: Self) -> Vector<Self>
57+
static func divide(_ a: Matrix<Self>, _ k: Self) -> Matrix<Self>
58+
static func divide(_ a: ShapedArray<Self>, _ k: Self) -> ShapedArray<Self>
59+
60+
// A / B
61+
static func divide(_ a: Vector<Self>, _ b: Vector<Self>) -> Vector<Self>
62+
static func divide(_ a: Matrix<Self>, _ b: Matrix<Self>) -> Matrix<Self>
63+
static func divide(_ a: ShapedArray<Self>, _ b: ShapedArray<Self>) -> ShapedArray<Self>
64+
}
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
/*
2+
Double extension to conform to Arithmetic protocol.
3+
*/
4+
5+
import Accelerate
6+
7+
@_documentation(visibility: private)
8+
extension Double: Arithmetic {
9+
// A + k
10+
public static func add(_ a: Vector<Double>, _ k: Double) -> Vector<Double> {
11+
var vec = Vector(like: a)
12+
vDSP.add(k, a.buffer, result: &vec.buffer)
13+
return vec
14+
}
15+
16+
public static func add(_ a: Matrix<Double>, _ k: Double) -> Matrix<Double> {
17+
var mat = Matrix(like: a)
18+
vDSP.add(k, a.buffer, result: &mat.buffer)
19+
return mat
20+
}
21+
22+
public static func add(_ a: ShapedArray<Double>, _ k: Double) -> ShapedArray<Double> {
23+
var arr = ShapedArray<Double>(shape: a.shape)
24+
vDSP.add(k, a.buffer, result: &arr.buffer)
25+
return arr
26+
}
27+
28+
// A + B
29+
public static func add(_ a: Vector<Double>, _ b: Vector<Double>) -> Vector<Double> {
30+
var vec = Vector(like: a)
31+
vDSP.add(a.buffer, b.buffer, result: &vec.buffer)
32+
return vec
33+
}
34+
35+
public static func add(_ a: Matrix<Double>, _ b: Matrix<Double>) -> Matrix<Double> {
36+
var mat = Matrix(like: a)
37+
vDSP.add(a.buffer, b.buffer, result: &mat.buffer)
38+
return mat
39+
}
40+
41+
public static func add(_ a: ShapedArray<Double>, _ b: ShapedArray<Double>) -> ShapedArray<Double> {
42+
var arr = ShapedArray<Double>(shape: a.shape)
43+
vDSP.add(a.buffer, b.buffer, result: &arr.buffer)
44+
return arr
45+
}
46+
47+
// k - A
48+
public static func subtract(_ k: Double, _ a: Vector<Double>) -> Vector<Double> {
49+
let arr = Array(repeating: k, count: a.size)
50+
var res = Vector(like: a)
51+
vDSP.subtract(arr, a.buffer, result: &res.buffer)
52+
return res
53+
}
54+
55+
public static func subtract(_ k: Double, _ a: Matrix<Double>) -> Matrix<Double> {
56+
let arr = Array(repeating: k, count: a.buffer.count)
57+
var mat = Matrix(like: a)
58+
vDSP.subtract(arr, a.buffer, result: &mat.buffer)
59+
return mat
60+
}
61+
62+
public static func subtract(_ k: Double, _ a: ShapedArray<Double>) -> ShapedArray<Double> {
63+
let arr = Array(repeating: k, count: a.buffer.count)
64+
var result = ShapedArray<Double>(shape: a.shape)
65+
vDSP.subtract(arr, a.buffer, result: &result.buffer)
66+
return result
67+
}
68+
69+
// A - k
70+
public static func subtract(_ a: Vector<Double>, _ k: Double) -> Vector<Double> {
71+
let arr = Array(repeating: k, count: a.size)
72+
var res = Vector(like: a)
73+
vDSP.subtract(a.buffer, arr, result: &res.buffer)
74+
return res
75+
}
76+
77+
public static func subtract(_ a: Matrix<Double>, _ k: Double) -> Matrix<Double> {
78+
let arr = Array(repeating: k, count: a.buffer.count)
79+
var mat = Matrix(like: a)
80+
vDSP.subtract(a.buffer, arr, result: &mat.buffer)
81+
return mat
82+
}
83+
84+
public static func subtract(_ a: ShapedArray<Double>, _ k: Double) -> ShapedArray<Double> {
85+
let arr = Array(repeating: k, count: a.buffer.count)
86+
var result = ShapedArray<Double>(shape: a.shape)
87+
vDSP.subtract(arr, a.buffer, result: &result.buffer)
88+
return result
89+
}
90+
91+
// A - B
92+
public static func subtract(_ a: Vector<Double>, _ b: Vector<Double>) -> Vector<Double> {
93+
var res = Vector(like: a)
94+
vDSP.subtract(a.buffer, b.buffer, result: &res.buffer)
95+
return res
96+
}
97+
98+
public static func subtract(_ a: Matrix<Double>, _ b: Matrix<Double>) -> Matrix<Double> {
99+
var mat = Matrix(like: a)
100+
vDSP.subtract(a.buffer, b.buffer, result: &mat.buffer)
101+
return mat
102+
}
103+
104+
public static func subtract(_ a: ShapedArray<Double>, _ b: ShapedArray<Double>) -> ShapedArray<Double> {
105+
precondition(a.shape == b.shape, "Shaped arrays must have same shape")
106+
var result = ShapedArray<Double>(shape: a.shape)
107+
vDSP.subtract(a.buffer, b.buffer, result: &result.buffer)
108+
return result
109+
}
110+
111+
// A * k
112+
public static func multiply(_ a: Vector<Double>, _ k: Double) -> Vector<Double> {
113+
var vec = Vector(like: a)
114+
vDSP.multiply(k, a.buffer, result: &vec.buffer)
115+
return vec
116+
}
117+
118+
public static func multiply(_ a: Matrix<Double>, _ k: Double) -> Matrix<Double> {
119+
var mat = Matrix(like: a)
120+
vDSP.multiply(k, a.buffer, result: &mat.buffer)
121+
return mat
122+
}
123+
124+
public static func multiply(_ a: ShapedArray<Double>, _ k: Double) -> ShapedArray<Double> {
125+
var array = ShapedArray<Double>(shape: a.shape)
126+
vDSP.multiply(k, a.buffer, result: &array.buffer)
127+
return array
128+
}
129+
130+
// A * B
131+
public static func multiply(_ a: Vector<Double>, _ b: Vector<Double>) -> Vector<Double> {
132+
var vec = Vector(like: a)
133+
vDSP.multiply(a.buffer, b.buffer, result: &vec.buffer)
134+
return vec
135+
}
136+
137+
public static func multiply(_ a: Matrix<Double>, _ b: Matrix<Double>) -> Matrix<Double> {
138+
var mat = Matrix(like: a)
139+
vDSP.multiply(a.buffer, b.buffer, result: &mat.buffer)
140+
return mat
141+
}
142+
143+
public static func multiply(_ a: ShapedArray<Double>, _ b: ShapedArray<Double>) -> ShapedArray<Double> {
144+
var array = ShapedArray<Double>(shape: a.shape)
145+
vDSP.multiply(a.buffer, b.buffer, result: &array.buffer)
146+
return array
147+
}
148+
149+
// A x B
150+
public static func matrixMultiply(_ a: Matrix<Double>, _ b: Matrix<Double>) -> Matrix<Double> {
151+
precondition(a.columns == b.rows, "Number of columns in matrix A must equal number of rows in matrix B")
152+
153+
let m = a.rows // number of rows in matrices A and C
154+
let n = b.columns // number of columns in matrices B and C
155+
let k = a.columns // number of columns in matrix A; number of rows in matrix B
156+
let alpha = 1.0
157+
let beta = 0.0
158+
159+
// matrix multiplication where C ← αAB + βC
160+
let c = Matrix<Double>(rows: a.rows, columns: b.columns)
161+
cblas_dgemm(
162+
CblasRowMajor, CblasNoTrans, CblasNoTrans,
163+
m, n, k, alpha,
164+
a.buffer.baseAddress, k,
165+
b.buffer.baseAddress, n,
166+
beta,
167+
c.buffer.baseAddress, n
168+
)
169+
170+
return c
171+
}
172+
173+
// k / A
174+
public static func divide(_ k: Double, _ a: Vector<Double>) -> Vector<Double> {
175+
var vec = Vector(like: a)
176+
for i in 0..<a.size {
177+
vec[i] = k / a[i]
178+
}
179+
return vec
180+
}
181+
182+
public static func divide(_ k: Double, _ a: Matrix<Double>) -> Matrix<Double> {
183+
var mat = Matrix(like: a)
184+
vDSP.divide(k, a.buffer, result: &mat.buffer)
185+
return mat
186+
}
187+
188+
public static func divide(_ k: Double, _ a: ShapedArray<Double>) -> ShapedArray<Double> {
189+
var arr = ShapedArray<Double>(shape: a.shape)
190+
vDSP.divide(k, a.buffer, result: &arr.buffer)
191+
return arr
192+
}
193+
194+
// A / k
195+
public static func divide(_ a: Vector<Double>, _ k: Double) -> Vector<Double> {
196+
var vec = Vector(like: a)
197+
for i in 0..<a.size {
198+
vec[i] = a[i] / k
199+
}
200+
return vec
201+
}
202+
203+
public static func divide(_ a: Matrix<Double>, _ k: Double) -> Matrix<Double> {
204+
var mat = Matrix(like: a)
205+
vDSP.divide(a.buffer, k, result: &mat.buffer)
206+
return mat
207+
}
208+
209+
public static func divide(_ a: ShapedArray<Double>, _ k: Double) -> ShapedArray<Double> {
210+
var arr = ShapedArray<Double>(shape: a.shape)
211+
vDSP.divide(a.buffer, k, result: &arr.buffer)
212+
return arr
213+
}
214+
215+
// A / B
216+
public static func divide(_ a: Vector<Self>, _ b: Vector<Self>) -> Vector<Self> {
217+
var vec = Vector(like: a)
218+
for i in 0..<a.size {
219+
vec[i] = a[i] / b[i]
220+
}
221+
return vec
222+
}
223+
224+
public static func divide(_ a: Matrix<Self>, _ b: Matrix<Self>) -> Matrix<Self> {
225+
var mat = Matrix(like: a)
226+
vDSP.divide(a.buffer, b.buffer, result: &mat.buffer)
227+
return mat
228+
}
229+
230+
public static func divide(_ a: ShapedArray<Double>, _ b: ShapedArray<Double>) -> ShapedArray<Double> {
231+
var arr = ShapedArray<Double>(shape: a.shape)
232+
vDSP.divide(a.buffer, b.buffer, result: &arr.buffer)
233+
return arr
234+
}
235+
}

0 commit comments

Comments
 (0)