-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_splits.py
More file actions
26 lines (17 loc) · 866 Bytes
/
create_splits.py
File metadata and controls
26 lines (17 loc) · 866 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numpy as np
import os
from RecSysFramework.Utils import invert_dictionary
from RecSysFramework.ExperimentalConfig import EXPERIMENTAL_CONFIG
def _split_and_save(splitter, dataset, seed, **save_kwargs):
    """Seed NumPy's global RNG, split ``dataset`` and save the three parts.

    Args:
        splitter: Object exposing ``split(dataset)`` and ``save_split(parts, ...)``.
        dataset: Dataset object handed to ``splitter.split``.
        seed: Value passed to ``np.random.seed`` right before splitting.
        **save_kwargs: Extra keyword arguments forwarded untouched to
            ``splitter.save_split`` (e.g. ``filename_suffix``).
    """
    # Re-seeding immediately before every split ties each saved split to a
    # fixed seed, independent of whatever consumed random numbers earlier.
    # NOTE(review): presumably the splitter draws from NumPy's global RNG --
    # confirm against the splitter implementations.
    np.random.seed(seed)
    train_part, test_part, validation_part = splitter.split(dataset)
    splitter.save_split([train_part, test_part, validation_part], **save_kwargs)


# For every (splitter, dataset) pair in the experimental configuration:
# load and save the dataset, then write one base split plus one split per fold.
for splitter in EXPERIMENTAL_CONFIG['splits']:
    for dataset_config in EXPERIMENTAL_CONFIG['datasets']:
        # 'datareader' holds a callable; calling it yields the reader instance.
        reader = dataset_config['datareader']()
        postprocessing_steps = dataset_config['postprocessings']

        dataset = reader.load_data(postprocessings=postprocessing_steps)
        dataset.save_data()

        # Base split: fixed seed 42, saved without a filename suffix.
        _split_and_save(splitter, dataset, 42)

        # Fold splits: fold k uses seed k + 1 and is saved with suffix "_k".
        for fold in range(EXPERIMENTAL_CONFIG['n_folds']):
            _split_and_save(splitter, dataset, fold + 1,
                            filename_suffix="_{}".format(fold))