Skip to content

Commit 3628982

Browse files
committed
update optimization test
1 parent e94f256 commit 3628982

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

test/model/test_optimization.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1+
import os
2+
import shutil
13
import tensorflow as tf
24
from detext.train import optimization
35

46

7+
tmp_out_dir = "/tmp/detext-optimization-test"
8+
9+
510
class OptimizationTest(tf.test.TestCase):
611

12+
def _cleanUp(self, tmp_out_dir):
13+
if os.path.exists(tmp_out_dir):
14+
shutil.rmtree(tmp_out_dir, ignore_errors=True)
15+
716
def test_adam(self):
817
with self.test_session() as sess:
918
w = tf.get_variable(
@@ -45,8 +54,12 @@ def test_different_lr(self):
4554
num_warmup_steps=0,
4655
lr_bert=0.00001,
4756
optimizer="bert_adam",
48-
use_horovod=False
57+
use_horovod=False,
58+
ftr_ext='cnn',
59+
out_dir=tmp_out_dir
4960
)
61+
if not tf.gfile.Exists(hparams.out_dir):
62+
tf.gfile.MakeDirs(hparams.out_dir)
5063
train_op, _, _ = optimization.create_optimizer(hparams, loss)
5164

5265
init_op = tf.group(tf.global_variables_initializer(),
@@ -57,6 +70,7 @@ def test_different_lr(self):
5770
print(bert_w_v, non_bert_w_v)
5871
# The difference of weight values (gradient) reflects the learning rate difference
5972
self.assertAllClose((bert_w_v - 0.1) / (non_bert_w_v - 0.1), [0.01, 0.01, 0.01], rtol=1e-2, atol=1e-2)
73+
self._cleanUp(tmp_out_dir)
6074

6175

6276
if __name__ == "__main__":

0 commit comments

Comments
 (0)