-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathloadDataset.py
More file actions
56 lines (43 loc) · 1.9 KB
/
loadDataset.py
File metadata and controls
56 lines (43 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
import torch
import numpy as np
import pandas as pd
from parameters import *
from normalization import Normalization
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, TensorDataset, DataLoader
def getDataset(data_dir='./'):
    """Load ``data.csv`` and return normalized train/test subsets.

    Reads ``data.csv`` from ``data_dir``, casts the ``volFrac``, ``thetaX``,
    ``thetaY`` and ``thetaZ`` columns to float, builds feature and label
    tensors from the ``featureNames`` / ``labelNames`` columns, normalizes
    the features, and randomly splits the resulting dataset.

    ``featureNames``, ``labelNames`` and ``trainSplit`` are module-level
    names (presumably supplied by ``from parameters import *`` -- confirm).

    Args:
        data_dir: Directory that contains ``data.csv``. Defaults to ``'./'``,
            the location that used to be hard-coded.

    Returns:
        A tuple ``(train_set, test_set, featureNormalization)`` where the two
        subsets come from ``torch.utils.data.random_split`` and
        ``featureNormalization`` is the ``Normalization`` object built from
        the feature tensor.
    """
    # Local import so the block does not depend on a file-level `import os`.
    import os

    data = pd.read_csv(os.path.join(data_dir, 'data.csv'))

    # Explicit column indexing instead of attribute-style assignment.
    for column in ('volFrac', 'thetaX', 'thetaY', 'thetaZ'):
        data[column] = data[column].astype(float)

    print('Data: ',data.shape)

    ##############---INIT TENSORS---##############
    featureTensor = torch.tensor(data[featureNames].values)
    labelTensor = torch.tensor(data[labelNames].values)

    ##############---INIT NORMALIZATION---##############
    # NOTE: the normalization is built from the full feature tensor, i.e.
    # before the train/test split below.
    featureNormalization = Normalization(featureTensor)
    featureTensor = featureNormalization.normalize(featureTensor)

    ##############---INIT Dataset and loader---##############
    dataset = TensorDataset(featureTensor.float(), labelTensor.float())
    l1 = round(len(dataset)*trainSplit)
    # The remainder goes to the test subset so both lengths sum to len(dataset).
    l2 = len(dataset) - l1
    print('train/test: ',[l1,l2])
    train_set, test_set = torch.utils.data.random_split(dataset, [l1,l2])
    return train_set, test_set, featureNormalization
#################################################
def exportTensor(name,data,cols, header=True):
    """Write a tensor to ``<name>.csv`` through a pandas DataFrame.

    Args:
        name: Output path without extension; ``".csv"`` is appended.
        data: Torch tensor to export (presumably 2-D, one CSV row per
            tensor row -- confirm against callers). It is detached from the
            autograd graph and moved to the CPU before conversion.
        cols: Column labels assigned to the DataFrame when ``header`` is
            truthy; ignored otherwise.
        header: Forwarded to ``DataFrame.to_csv`` and also controls whether
            ``cols`` is applied.
    """
    # `.cpu()` is a no-op for CPU tensors, but `.numpy()` raises for tensors
    # that live on another device, so move the data first.
    array = data.detach().cpu().numpy()
    df = pd.DataFrame.from_records(array)
    if header:
        df.columns = cols
    print(name)
    df.to_csv(name+".csv",header=header)
def exportList(name,data):
    """Save ``data`` as a single comma-separated row in ``<name>.csv``.

    Args:
        name: Output path without extension; ``".csv"`` is appended.
        data: Sequence of numbers; it is converted with ``np.array``.
    """
    # A leading axis makes savetxt emit the values as one row.
    single_row = np.array(data)[np.newaxis, ...]
    target = name + ".csv"
    np.savetxt(target, single_row, delimiter=',')