forked from dereklstinson/gocudnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcudnnRNN_clip.go
41 lines (33 loc) · 1.28 KB
/
cudnnRNN_clip.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
// RNNClipMode is a flag selecting the clip mode for an RNN descriptor.
// It is a Go defined type over the cuDNN enum C.cudnnRNNClipMode_t; the
// None and MinMax methods set it to CUDNN_RNN_CLIP_NONE and
// CUDNN_RNN_CLIP_MINMAX respectively. It is passed to (*RNND).SetClip and
// returned by (*RNND).GetClip.
type RNNClipMode C.cudnnRNNClipMode_t
// c converts the Go-side clip mode back to the raw cuDNN enum so it can be
// handed to C functions across the cgo boundary.
func (r RNNClipMode) c() C.cudnnRNNClipMode_t {
	cMode := C.cudnnRNNClipMode_t(r)
	return cMode
}
// None stores the cuDNN CUDNN_RNN_CLIP_NONE mode into r and returns that
// same value, so it can be used both as a setter and inline as an argument.
func (r *RNNClipMode) None() RNNClipMode {
	mode := RNNClipMode(C.CUDNN_RNN_CLIP_NONE)
	*r = mode
	return mode
}
// MinMax stores the cuDNN CUDNN_RNN_CLIP_MINMAX mode into r and returns
// that same value, so it can be used both as a setter and inline as an
// argument.
func (r *RNNClipMode) MinMax() RNNClipMode {
	mode := RNNClipMode(C.CUDNN_RNN_CLIP_MINMAX)
	*r = mode
	return mode
}
// SetClip writes the clipping configuration into the RNN descriptor by
// calling cudnnRNNSetClip with the given handle.
//
// mode selects the clip mode, nanprop the NaN propagation mode, and lclip /
// rclip are forwarded as the left and right clip bounds (per the cuDNN
// documentation, the lower and upper limits used with MINMAX clipping).
// A non-successful cuDNN status is converted to a non-nil error labelled
// "SetClip".
func (r *RNND) SetClip(h *Handle, mode RNNClipMode, nanprop NANProp, lclip, rclip float64) error {
	status := C.cudnnRNNSetClip(
		h.x,
		r.descriptor,
		mode.c(),
		nanprop.c(),
		C.double(lclip),
		C.double(rclip),
	)
	return Status(status).error("SetClip")
}
//GetClip returns the clip settings for the descriptor
func (r *RNND) GetClip(h *Handle) (mode RNNClipMode, nanprop NANProp, lclip, rclip float64, err error) {
var (
m C.cudnnRNNClipMode_t
nan C.cudnnNanPropagation_t
lt C.double
rt C.double
)
err = Status(C.cudnnRNNGetClip(h.x, r.descriptor, &m, &nan, <, &rt)).error("SetClip")
mode = RNNClipMode(m)
nanprop = NANProp(nan)
lclip = float64(lt)
rclip = float64(rt)
return mode, nanprop, lclip, rclip, err
}