-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_bucket_tensorboard_callback.py
117 lines (93 loc) · 3.75 KB
/
test_bucket_tensorboard_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import unittest
from os import listdir
from os.path import exists
from random import choice
from shutil import rmtree
from string import ascii_lowercase
from sys import argv
from tempfile import gettempdir
from google.cloud import storage
from keras.layers import Dense
from keras.models import Sequential
from numpy.random import rand, randint
from keras_bucket_tensorboard_callback import BucketTensorBoard
class TestBucketTensorboardCallback(unittest.TestCase):
    """Integration test for the ``BucketTensorBoard`` Keras callback.

    The fixtures below are not built by the class itself; whoever runs the
    tests must assign them as class attributes first (this file's
    ``__main__`` block does so from the command line). When they are left
    unset the bucket test is skipped instead of erroring out.
    """

    # Base bucket URI; every run logs under a random sub-path of it.
    bucket_uri = None
    # Model that gets compiled and fitted with the callback attached.
    model = None
    # Training inputs and labels handed to ``model.fit``.
    X = None
    Y = None

    @staticmethod
    def _blob_prefix(bucket_uri):
        """Return the object prefix of *bucket_uri* (the path after the bucket).

        The first five characters are dropped as the scheme, so the URI is
        assumed to look like ``gs://<bucket>/<path>``.
        """
        return bucket_uri[5:].split('/', 1)[1]

    @staticmethod
    def _accuracy_key(history):
        """Return the key under which *history* stores the accuracy metric.

        Newer Keras releases record the metric as ``'accuracy'`` while older
        ones use ``'acc'``; prefer the former and fall back to the latter.
        """
        return 'accuracy' if 'accuracy' in history else 'acc'

    @staticmethod
    def _count_blobs(bucket, prefix):
        """Return how many blobs *bucket* currently lists under *prefix*."""
        return len(list(bucket.list_blobs(prefix=prefix)))

    @staticmethod
    def _purge(logs_dir, bucket, prefix):
        """Remove the local *logs_dir* and every blob in *bucket* under *prefix*."""
        if exists(logs_dir):
            rmtree(logs_dir, ignore_errors=True)
        for blob in bucket.list_blobs(prefix=prefix):
            blob.delete()

    def test_logs_uploaded_gcp(self):
        """Each ``fit`` call must add one log file locally and one in the bucket."""
        fixtures = (self.bucket_uri, self.model, self.X, self.Y)
        if any(fixture is None for fixture in fixtures):
            self.skipTest('bucket_uri, model, X and Y must be set on the class')

        # The callback we need to test; the random suffix keeps runs apart
        suffix = ''.join(choice(ascii_lowercase) for _ in range(10))
        bucket_uri = f"{self.bucket_uri.strip('/')}/{suffix}"
        bucket_tb = BucketTensorBoard(bucket_uri)
        bucket = bucket_tb.bucket

        logs_dir = f'{gettempdir()}/tensorboard_callbacks/{bucket_uri[5:]}'
        prefix = self._blob_prefix(bucket_uri)

        # Remove local and cloud artifacts even when an assertion below fails
        self.addCleanup(self._purge, logs_dir, bucket, prefix)

        # Make sure that neither the temp directory nor the cloud directory
        # exists before training
        self._purge(logs_dir, bucket, prefix)

        # Compile the model
        self.model.compile(
            optimizer='sgd',
            loss='binary_crossentropy',
            metrics=['accuracy'],
        )

        # Train the model
        history = self.model.fit(
            x=self.X,
            y=self.Y,
            verbose=0,
            epochs=3,
            callbacks=[bucket_tb],
        )

        # Check if the model was trained for 3 epochs, so we can check if
        # tensorboard is showing all 3 epochs
        metrics = history.history
        self.assertEqual(len(metrics['loss']), 3)
        self.assertEqual(len(metrics[self._accuracy_key(metrics)]), 3)

        # Check if the temp directory has one file
        self.assertEqual(1, len(listdir(logs_dir)))
        # Check if the cloud directory has one file
        self.assertEqual(1, self._count_blobs(bucket, prefix))

        # Train the model again
        self.model.fit(
            x=self.X,
            y=self.Y,
            verbose=0,
            epochs=3,
            callbacks=[bucket_tb],
        )

        # Check if the temp directory has two files
        self.assertEqual(2, len(listdir(logs_dir)))
        # Check if the cloud directory has two files
        self.assertEqual(2, self._count_blobs(bucket, prefix))

        # Erase the logs dir and the files created in the bucket
        self._purge(logs_dir, bucket, prefix)

        # Check if the temp directory has no files
        self.assertFalse(exists(logs_dir))
        # Check if the cloud directory has no files
        self.assertEqual(0, self._count_blobs(bucket, prefix))

    def test_logs_uploaded_aws(self):
        """Placeholder: nothing is exercised for AWS buckets yet."""
        pass
if __name__ == '__main__':
    # Exactly one command-line argument is expected: the bucket URI.
    if len(argv) != 2:
        # Raising SystemExit with a message writes it to stderr and ends the
        # process with a non-zero status, so misuse is visible to callers.
        raise SystemExit(
            'Usage: python test_bucket_tensorboard_callback.py bucket_uri')

    # Pop the URI off argv so that unittest.main() does not try to interpret
    # it as the name of a test to run.
    TestBucketTensorboardCallback.bucket_uri = argv.pop()

    # Simple model for testing
    TestBucketTensorboardCallback.model = Sequential([
        Dense(10, activation='relu', input_shape=(10,)),
        Dense(10, activation='relu'),
        Dense(1, activation='sigmoid'),
    ])

    # The data we generate for testing
    dataset_size = 100
    TestBucketTensorboardCallback.X = rand(dataset_size, 10) * 10
    TestBucketTensorboardCallback.Y = randint(2, size=dataset_size)

    # Begin testing
    unittest.main()