forked from dereklstinson/gocudnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcudnnRNN_algofindbwf.go
124 lines (111 loc) · 3.29 KB
/
cudnnRNN_algofindbwf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
import (
"unsafe"
"github.com/dereklstinson/cutil"
)
// getRNNBackwardWeightsAlgorithmMaxCount returns the maximum number of
// backward-weights algorithms that cuDNN can report for this RNN descriptor,
// as given by cudnnGetRNNBackwardWeightsAlgorithmMaxCount. The Find* methods
// use it to size their performance-result buffers.
func (r *RNND) getRNNBackwardWeightsAlgorithmMaxCount(handle *Handle) (int32, error) {
	var count C.int
	err := Status(C.cudnnGetRNNBackwardWeightsAlgorithmMaxCount(
		handle.x,
		r.descriptor,
		&count,
	)).error("GetRNNBackwardWeightsAlgorithmMaxCount")
	return int32(count), err
}
// FindRNNBackwardWeightsAlgorithmEx benchmarks the backward-weights algorithms
// available for this RNN descriptor via cudnnFindRNNBackwardWeightsAlgorithmEx
// and returns the performance results cuDNN reported.
//
// xD and yD hold the per-time-step input/output tensor descriptors. len(xD) is
// passed to cuDNN as the sequence length, so xD must be non-empty and yD must
// have at least len(xD) descriptors; otherwise a bad-parameter error is
// returned without calling cuDNN.
//
// Buffer handling:
//   - hx may be nil; it is then passed as NULL, which cuDNN treats as a zero
//     initial hidden state.
//   - wspace may be nil; it is then passed as NULL with size 0, and
//     wspacesize is ignored.
//
// findIntensity is forwarded to cuDNN unchanged; cuDNN reserves it for future
// use.
//
// Only the entries cuDNN actually filled in are returned. On error the returned
// slice is nil. If cuDNN reports no algorithms, the result is (nil, nil).
func (r *RNND) FindRNNBackwardWeightsAlgorithmEx(
	handle *Handle,
	xD []*TensorD, x cutil.Mem,
	hxD *TensorD, hx cutil.Mem,
	yD []*TensorD, y cutil.Mem,
	findIntensity float32, // reserved by cuDNN for future use
	wspace cutil.Mem, wspacesize uint,
	dwD *FilterD, dw cutil.Mem,
	rspace cutil.Mem, rspacesize uint,
) ([]AlgorithmPerformance, error) {
	// cuDNN reads len(xD) descriptors from both arrays. Reject inputs that would
	// index an empty slice or read past the end of yD.
	if len(xD) == 0 || len(yD) < len(xD) {
		return nil, Status(C.CUDNN_STATUS_BAD_PARAM).error("FindRNNBackwardWeightsAlgorithmEx")
	}
	reqAlgocount, err := r.getRNNBackwardWeightsAlgorithmMaxCount(handle)
	if err != nil {
		return nil, err
	}
	if reqAlgocount < 1 {
		// Nothing to benchmark. This also avoids taking &perfresults[0] on an
		// empty slice below.
		return nil, nil
	}

	// Optional buffers: a nil cutil.Mem becomes a NULL device pointer.
	var hxPtr unsafe.Pointer
	if hx != nil {
		hxPtr = hx.Ptr()
	}
	var wspacePtr unsafe.Pointer
	var wspaceBytes C.size_t
	if wspace != nil {
		wspacePtr = wspace.Ptr()
		wspaceBytes = C.size_t(wspacesize)
	}

	seqLength := C.int(len(xD))
	inCxD := tensorDArrayToC(xD)
	inCyD := tensorDArrayToC(yD)
	// NOTE(review): cuDNN documents cudnnAlgorithmPerformance_t as an opaque
	// handle normally created with cudnnCreateAlgorithmPerformance; confirm
	// whether these entries must be created before the find call.
	perfresults := make([]C.cudnnAlgorithmPerformance_t, reqAlgocount)
	var actualcount C.int
	err = Status(C.cudnnFindRNNBackwardWeightsAlgorithmEx(
		handle.x,
		r.descriptor,
		seqLength,
		&inCxD[0], x.Ptr(),
		hxD.descriptor, hxPtr,
		&inCyD[0], y.Ptr(),
		C.float(findIntensity),
		C.int(reqAlgocount),
		&actualcount,
		&perfresults[0],
		wspacePtr, wspaceBytes,
		dwD.descriptor, dw.Ptr(),
		rspace.Ptr(), C.size_t(rspacesize),
	)).error("FindRNNBackwardWeightsAlgorithmEx")
	if err != nil {
		return nil, err
	}

	// Return only the slots cuDNN wrote. The clamp is defensive: cuDNN should
	// never report more than was requested.
	filled := int(actualcount)
	if filled < 0 {
		filled = 0
	} else if filled > len(perfresults) {
		filled = len(perfresults)
	}
	return calgoperftogoarray(perfresults[:filled], handle.gogc), nil
}
// FindRNNBackwardWeightsAlgorithmExUS is like FindRNNBackwardWeightsAlgorithmEx
// but takes raw unsafe.Pointer buffers instead of cutil.Mem.
//
// xD and yD hold the per-time-step input/output tensor descriptors. len(xD) is
// passed to cuDNN as the sequence length, so xD must be non-empty and yD must
// have at least len(xD) descriptors; otherwise a bad-parameter error is
// returned without calling cuDNN.
//
// All pointers are forwarded to cuDNN as-is. In particular, a nil hx means a
// zero initial hidden state to cuDNN. findIntensity is forwarded unchanged;
// cuDNN reserves it for future use.
//
// Only the entries cuDNN actually filled in are returned. On error the returned
// slice is nil. If cuDNN reports no algorithms, the result is (nil, nil).
func (r *RNND) FindRNNBackwardWeightsAlgorithmExUS(
	handle *Handle,
	xD []*TensorD, x unsafe.Pointer,
	hxD *TensorD, hx unsafe.Pointer,
	yD []*TensorD, y unsafe.Pointer,
	findIntensity float32, // reserved by cuDNN for future use
	wspace unsafe.Pointer, wspacesize uint,
	dwD *FilterD, dw unsafe.Pointer,
	rspace unsafe.Pointer, rspacesize uint,
) ([]AlgorithmPerformance, error) {
	// cuDNN reads len(xD) descriptors from both arrays. Reject inputs that would
	// index an empty slice or read past the end of yD.
	if len(xD) == 0 || len(yD) < len(xD) {
		return nil, Status(C.CUDNN_STATUS_BAD_PARAM).error("FindRNNBackwardWeightsAlgorithmExUS")
	}
	reqAlgocount, err := r.getRNNBackwardWeightsAlgorithmMaxCount(handle)
	if err != nil {
		return nil, err
	}
	if reqAlgocount < 1 {
		// Nothing to benchmark. This also avoids taking &perfresults[0] on an
		// empty slice below.
		return nil, nil
	}

	seqLength := C.int(len(xD))
	inCxD := tensorDArrayToC(xD)
	inCyD := tensorDArrayToC(yD)
	// NOTE(review): cuDNN documents cudnnAlgorithmPerformance_t as an opaque
	// handle normally created with cudnnCreateAlgorithmPerformance; confirm
	// whether these entries must be created before the find call.
	perfresults := make([]C.cudnnAlgorithmPerformance_t, reqAlgocount)
	var actualcount C.int
	err = Status(C.cudnnFindRNNBackwardWeightsAlgorithmEx(
		handle.x,
		r.descriptor,
		seqLength,
		&inCxD[0], x,
		hxD.descriptor, hx,
		&inCyD[0], y,
		C.float(findIntensity),
		C.int(reqAlgocount),
		&actualcount,
		&perfresults[0],
		wspace, C.size_t(wspacesize),
		dwD.descriptor, dw,
		rspace, C.size_t(rspacesize),
	)).error("FindRNNBackwardWeightsAlgorithmEx")
	if err != nil {
		return nil, err
	}

	// Return only the slots cuDNN wrote. The clamp is defensive: cuDNN should
	// never report more than was requested.
	filled := int(actualcount)
	if filled < 0 {
		filled = 0
	} else if filled > len(perfresults) {
		filled = len(perfresults)
	}
	return calgoperftogoarray(perfresults[:filled], handle.gogc), nil
}