forked from dereklstinson/gocudnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcudnnSoftMax.go
169 lines (145 loc) · 4.2 KB
/
cudnnSoftMax.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
import (
"unsafe"
"github.com/dereklstinson/cutil"
)
//SoftMaxD holds the soft max flags and soft max funcs
type SoftMaxD struct {
	set  bool                      // tracks whether the descriptor has been configured — NOTE(review): never written by Set in this file; confirm intent
	algo C.cudnnSoftmaxAlgorithm_t // softmax algorithm flag (fast / accurate / log)
	mode C.cudnnSoftmaxMode_t      // softmax mode flag (instance / channel)
}
//CreateSoftMaxDescriptor creates a gocudnn softmax descriptor. It is not part
//of cudnn itself; it exists to streamline use of the softmax functions.
func CreateSoftMaxDescriptor() *SoftMaxD {
	return new(SoftMaxD)
}
//Set sets the softmax algorithm and mode used by the Forward and Backward
//methods, and marks the descriptor as configured.
//
//The error return is always nil; it is kept for signature consistency with
//other descriptor Set methods in this package.
func (s *SoftMaxD) Set(algo SoftMaxAlgorithm, mode SoftMaxMode) error {
	s.algo = algo.c()
	s.mode = mode.c()
	// Record that the descriptor has been configured. The set field was
	// previously declared but never written, so it stayed false forever.
	s.set = true
	return nil
}
//Get gets the softmax descriptor values
func (s *SoftMaxD) Get() (algo SoftMaxAlgorithm, mode SoftMaxMode, err error) {
	algo = SoftMaxAlgorithm(s.algo)
	mode = SoftMaxMode(s.mode)
	return algo, mode, nil
}
/* Softmax functions: All of the form "output = alpha * Op(inputs) + beta * output" */
//Forward performs forward softmax
//
//Input/Output: y
func (s *SoftMaxD) Forward(
	handle *Handle,
	alpha float64,
	xD *TensorD, x cutil.Mem,
	beta float64,
	yD *TensorD, y cutil.Mem) error {
	// The scaling factors must be converted to C scalars whose type matches
	// each tensor's data type before being handed to cudnn.
	alphaC := cscalarbydatatype(xD.dtype, alpha)
	betaC := cscalarbydatatype(yD.dtype, beta)
	status := C.cudnnSoftmaxForward(
		handle.x,
		s.algo,
		s.mode,
		alphaC.CPtr(),
		xD.descriptor, x.Ptr(),
		betaC.CPtr(),
		yD.descriptor, y.Ptr(),
	)
	return Status(status).error("SoftMaxForward")
}
//Backward performs the backward softmax
//
//Input/Output: dx
func (s *SoftMaxD) Backward(
	handle *Handle,
	alpha float64,
	yD *TensorD, y cutil.Mem,
	dyD *TensorD, dy cutil.Mem,
	beta float64,
	dxD *TensorD, dx cutil.Mem,
) error {
	// Scale factors are converted to C scalars matching the tensor dtypes.
	alphaC := cscalarbydatatype(yD.dtype, alpha)
	betaC := cscalarbydatatype(dxD.dtype, beta)
	status := C.cudnnSoftmaxBackward(
		handle.x,
		s.algo,
		s.mode,
		alphaC.CPtr(),
		yD.descriptor, y.Ptr(),
		dyD.descriptor, dy.Ptr(),
		betaC.CPtr(),
		dxD.descriptor, dx.Ptr(),
	)
	return Status(status).error("SoftMaxBackward")
}
//ForwardUS is like Forward but uses unsafe.Pointer instead of cutil.Mem
func (s *SoftMaxD) ForwardUS(
	handle *Handle,
	alpha float64,
	xD *TensorD, x unsafe.Pointer,
	beta float64,
	yD *TensorD, y unsafe.Pointer) error {
	// Scale factors are converted to C scalars matching the tensor dtypes.
	alphaC := cscalarbydatatype(xD.dtype, alpha)
	betaC := cscalarbydatatype(yD.dtype, beta)
	status := C.cudnnSoftmaxForward(
		handle.x,
		s.algo,
		s.mode,
		alphaC.CPtr(),
		xD.descriptor, x,
		betaC.CPtr(),
		yD.descriptor, y,
	)
	return Status(status).error("SoftMaxForward")
}
//BackwardUS is like Backward but uses unsafe.Pointer instead of cutil.Mem
func (s *SoftMaxD) BackwardUS(
	handle *Handle,
	alpha float64,
	yD *TensorD, y unsafe.Pointer,
	dyD *TensorD, dy unsafe.Pointer,
	beta float64,
	dxD *TensorD, dx unsafe.Pointer,
) error {
	// Scale factors are converted to C scalars matching the tensor dtypes.
	alphaC := cscalarbydatatype(yD.dtype, alpha)
	betaC := cscalarbydatatype(dxD.dtype, beta)
	status := C.cudnnSoftmaxBackward(
		handle.x,
		s.algo,
		s.mode,
		alphaC.CPtr(),
		yD.descriptor, y,
		dyD.descriptor, dy,
		betaC.CPtr(),
		dxD.descriptor, dx,
	)
	return Status(status).error("SoftMaxBackward")
}
//SoftMaxAlgorithm is used for flags and are exposed through its methods
//(Fast, Accurate, Log each set the receiver and return the new value).
type SoftMaxAlgorithm C.cudnnSoftmaxAlgorithm_t
//Fast changes s to and returns SoftMaxAlgorithm(C.CUDNN_SOFTMAX_FAST)
func (s *SoftMaxAlgorithm) Fast() SoftMaxAlgorithm { *s = SoftMaxAlgorithm(C.CUDNN_SOFTMAX_FAST); return *s }
//Accurate changes s to and returns SoftMaxAlgorithm(C.CUDNN_SOFTMAX_ACCURATE)
func (s *SoftMaxAlgorithm) Accurate() SoftMaxAlgorithm { *s = SoftMaxAlgorithm(C.CUDNN_SOFTMAX_ACCURATE); return *s }
//Log changes s to and returns SoftMaxAlgorithm(C.CUDNN_SOFTMAX_LOG)
func (s *SoftMaxAlgorithm) Log() SoftMaxAlgorithm { *s = SoftMaxAlgorithm(C.CUDNN_SOFTMAX_LOG); return *s }
func (s SoftMaxAlgorithm) c() C.cudnnSoftmaxAlgorithm_t { return C.cudnnSoftmaxAlgorithm_t(s) }
//SoftMaxMode is used for softmaxmode flags and are exposed through its methods
//(Instance, Channel each set the receiver and return the new value).
type SoftMaxMode C.cudnnSoftmaxMode_t
//Instance changes s to SoftMaxMode(C.CUDNN_SOFTMAX_MODE_INSTANCE) and returns changed value
func (s *SoftMaxMode) Instance() SoftMaxMode { *s = SoftMaxMode(C.CUDNN_SOFTMAX_MODE_INSTANCE); return *s }
//Channel changes s to SoftMaxMode(C.CUDNN_SOFTMAX_MODE_CHANNEL) and returns changed value
func (s *SoftMaxMode) Channel() SoftMaxMode {
	*s = SoftMaxMode(C.CUDNN_SOFTMAX_MODE_CHANNEL)
	return *s
}
func (s SoftMaxMode) c() C.cudnnSoftmaxMode_t { return C.cudnnSoftmaxMode_t(s) }