@@ -88,7 +88,7 @@ class TFTrainingArguments(TrainingArguments):
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
         debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Wheter to activate the trace to record computation graphs and profiling information or not.
+            Whether to activate the trace to record computation graphs and profiling information or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
@@ -103,19 +103,32 @@ class TFTrainingArguments(TrainingArguments):
             The name of the TPU the process is running on.
         run_name (:obj:`str`, `optional`):
             A descriptor for the run. Notably used for wandb logging.
+        xla (:obj:`bool`, `optional`):
+            Whether to activate the XLA compilation or not.
     """

     tpu_name: str = field(
         default=None,
         metadata={"help": "Name of TPU"},
     )

+    xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
+
     @cached_property
     @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
         logger.info("Tensorflow: setting up strategy")
+
+        if self.xla:
+            tf.config.optimizer.set_jit(True)
+
         gpus = tf.config.list_physical_devices("GPU")

+        # Set to float16 at first
+        if self.fp16:
+            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+            tf.keras.mixed_precision.experimental.set_policy(policy)
+
         if self.no_cuda:
             strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
         else:
@@ -128,10 +141,16 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
                 tpu = None

             if tpu:
+                # Set to bfloat16 in case of TPU
+                if self.fp16:
+                    policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
+                    tf.keras.mixed_precision.experimental.set_policy(policy)
+
                 tf.config.experimental_connect_to_cluster(tpu)
                 tf.tpu.experimental.initialize_tpu_system(tpu)

                 strategy = tf.distribute.experimental.TPUStrategy(tpu)
+
             elif len(gpus) == 0:
                 strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
             elif len(gpus) == 1:
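
With these changes, XLA compilation and mixed precision can be requested straight from the training arguments; _setup_strategy reads the two flags before building the tf.distribute strategy. A minimal usage sketch, assuming only the documented constructor arguments (the output_dir value is illustrative, not part of this diff):

from transformers import TFTrainingArguments

# Hypothetical example: xla triggers tf.config.optimizer.set_jit(True),
# fp16 selects the mixed_float16 policy on GPU and mixed_bfloat16 on TPU.
args = TFTrainingArguments(
    output_dir="./output",  # illustrative path, required argument
    xla=True,
    fp16=True,
)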