-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathautograd.swift
87 lines (71 loc) · 2.24 KB
/
autograd.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import Foundation  // provides exp(_:) used by sigmoid()

/// A scalar node in a reverse-mode automatic-differentiation graph.
///
/// Arithmetic on nodes (`+`, `*`, `/`, `sigmoid()`) produces new nodes that
/// remember their operands in `children` and install a `_backward` closure
/// applying the chain rule. Calling `backward()` on a result node fills in
/// `grad` for every node reachable from it.
class Node: Equatable {
    /// The scalar computed for this node in the forward pass.
    var value: Float
    /// Derivative accumulated by `backward()`. Starts at 0 and is only ever
    /// added to (except on the node `backward()` is called on, which is set
    /// to 1), so repeated backward passes accumulate unless reset by the caller.
    var grad: Float
    /// The operand nodes this node was computed from; empty for leaf nodes.
    var children: [Node]
    /// Pushes this node's `grad` into its operands' `grad`.
    ///
    /// Returns `true` when gradients were propagated. The default closure
    /// (used by nodes created directly through `init`) does nothing and
    /// returns `false`.
    var _backward: () -> Bool = { false }

    /// Creates a node holding `value` with a zero gradient.
    /// - Parameters:
    ///   - value: The scalar stored in the node.
    ///   - children: The operand nodes this value was derived from.
    init(value: Float, children: [Node] = []) {
        self.value = value
        self.grad = 0
        self.children = children
    }

    /// Returns a node holding `left.value + right.value`.
    ///
    /// Backward rule: both operands receive the result's gradient unchanged.
    static func + (left: Node, right: Node) -> Node {
        let out = Node(value: left.value + right.value, children: [left, right])
        // `out` owns this closure, so capture it weakly to avoid a retain cycle.
        out._backward = { [weak out] in
            guard let out = out else { return false }
            left.grad += out.grad
            right.grad += out.grad
            return true
        }
        return out
    }

    /// Returns a node holding `left.value * right.value`.
    ///
    /// Backward rule: each operand receives the other operand's value times
    /// the result's gradient.
    static func * (left: Node, right: Node) -> Node {
        let out = Node(value: left.value * right.value, children: [left, right])
        // `out` owns this closure, so capture it weakly to avoid a retain cycle.
        out._backward = { [weak out] in
            guard let out = out else { return false }
            left.grad += right.value * out.grad
            right.grad += left.value * out.grad
            return true
        }
        return out
    }

    /// Returns a node holding `left.value / right.value`.
    ///
    /// Backward rule: d/dleft = 1 / right, d/dright = -left / right².
    static func / (left: Node, right: Node) -> Node {
        let out = Node(value: left.value / right.value, children: [left, right])
        // `out` owns this closure, so capture it weakly to avoid a retain cycle.
        out._backward = { [weak out] in
            guard let out = out else { return false }
            left.grad += 1.0 / right.value * out.grad
            right.grad += -(left.value / (right.value * right.value) * out.grad)
            return true
        }
        return out
    }

    /// Returns a node holding the logistic sigmoid `1 / (1 + e^(-value))` of this node.
    ///
    /// Backward rule: uses the identity s'(x) = s(x) * (1 - s(x)), evaluated
    /// from the result node's own value.
    func sigmoid() -> Node {
        let activated = 1 / (1 + exp(-value))
        let out = Node(value: activated, children: [self])
        // `self` is an operand (already retained via `children`); only `out`
        // needs a weak capture because it owns the closure.
        out._backward = { [weak out] in
            guard let out = out else { return false }
            self.grad += out.grad * out.value * (1 - out.value)
            return true
        }
        return out
    }

    /// Runs backpropagation from this node.
    ///
    /// Collects every node reachable through `children` in post-order
    /// (operands before the nodes that use them), sets this node's `grad`
    /// to 1, then invokes each node's `_backward` in reverse of that order so
    /// a node's gradient is complete before it is pushed to its operands.
    func backward() {
        var order: [Node] = []
        // Track visited nodes by identity, not by `==`: two distinct nodes that
        // happen to be structurally equal must each run their own `_backward`.
        var visited = Set<ObjectIdentifier>()

        func visit(_ node: Node) {
            // Marking on entry also keeps a cyclic `children` graph from
            // recursing forever.
            guard visited.insert(ObjectIdentifier(node)).inserted else { return }
            for child in node.children {
                visit(child)
            }
            order.append(node)
        }

        visit(self)
        grad = 1
        for node in order.reversed() {
            _ = node._backward()
        }
    }

    /// Structural equality: two nodes are equal when their values match and
    /// their `children` arrays are (recursively) equal. `grad` is not compared.
    static func == (lhs: Node, rhs: Node) -> Bool {
        return lhs.value == rhs.value && lhs.children == rhs.children
    }
}