forked from dereklstinson/gocudnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzhelperfuncs.go
101 lines (95 loc) · 2 KB
/
zhelperfuncs.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package gocudnn
/*
#include <cudnn.h>
*/
import "C"
// int32Tocint converts a Go []int32 into a freshly allocated []C.int of the
// same length, converting each element in order.
func int32Tocint(x []int32) []C.int {
	out := make([]C.int, 0, len(x))
	for _, v := range x {
		out = append(out, C.int(v))
	}
	return out
}
// cintToint32 converts a []C.int into a freshly allocated Go []int32 of the
// same length, converting each element in order.
func cintToint32(x []C.int) []int32 {
	converted := make([]int32, len(x))
	for idx, v := range x {
		converted[idx] = int32(v)
	}
	return converted
}
// comparedims reports whether every slice in dims has the same length and the
// same elements, in the same order, as dims[0]. Passing zero or one slice
// trivially returns true.
func comparedims(dims ...[]int32) bool {
	if len(dims) < 2 {
		return true
	}
	ref := dims[0]
	for _, other := range dims[1:] {
		if len(other) != len(ref) {
			return false
		}
		for k, v := range ref {
			if other[k] != v {
				return false
			}
		}
	}
	return true
}
// findvolume returns the product of all entries in dims, or 1 for an empty
// slice. The product is computed in int32 and wraps silently on overflow.
func findvolume(dims []int32) int32 {
	vol := int32(1)
	for _, d := range dims {
		vol *= d
	}
	return vol
}
// stridecalc returns fully packed strides for dims, with the last dimension
// varying fastest. strides[len-1] is 1, and each earlier stride is the product
// of all later dimensions. An empty dims yields an empty (non-nil) slice.
func stridecalc(dims []int32) []int32 {
	n := len(dims)
	strides := make([]int32, n)
	acc := int32(1)
	for i := range dims {
		k := n - 1 - i
		strides[k] = acc
		acc *= dims[k]
	}
	return strides
}
// FindLength returns how many elements of type dtype fit in s bytes, i.e. s
// divided (with truncation) by the element size of dtype. It returns 0 for a
// DataType it does not recognize.
func FindLength(s uint, dtype DataType) uint32 {
	var flg DataType
	var elemSize uint
	switch dtype {
	case flg.Float():
		elemSize = 4
	case flg.Double():
		elemSize = 8
	case flg.Int32():
		elemSize = 4
	case flg.Int8(), flg.UInt8():
		elemSize = 1
	case flg.Half():
		elemSize = 2
	default:
		return 0
	}
	return uint32(s / elemSize)
}
// FindSizeTfromVol takes the dimensions of a tensor (volume) and returns the
// number of bytes needed to store that many elements of dtype. An empty
// volume counts as a single element. It returns 0 for a DataType it does not
// handle.
func FindSizeTfromVol(volume []int32, dtype DataType) uint {
	// Accumulate in int64. In int32, both the element count and the byte count
	// (count * element size) wrap for large tensors, e.g. > 512M float elements.
	vol := int64(1)
	for _, d := range volume {
		vol *= int64(d)
	}

	var elemSize int64
	switch dtype {
	case DataType(C.CUDNN_DATA_FLOAT), DataType(C.CUDNN_DATA_INT32):
		elemSize = 4
	case DataType(C.CUDNN_DATA_DOUBLE):
		elemSize = 8
	case DataType(C.CUDNN_DATA_HALF):
		elemSize = 2
	case DataType(C.CUDNN_DATA_INT8), DataType(C.CUDNN_DATA_UINT8):
		// UINT8 is also one byte per element; FindLength already handles it.
		elemSize = 1
	default:
		return uint(0)
	}
	return uint(vol * elemSize)
}