Commit 702a76f

Create an XLA parameter and fix the mixed precision (#7311)

* Create an XLA parameter and fix mixed precision creation
* Fix issue brought by intellisense
* Complete docstring
1 parent 596342c commit 702a76f

File tree

  src/transformers/trainer_tf.py
  src/transformers/training_args_tf.py

2 files changed: +20 -5 lines changed


src/transformers/trainer_tf.py

-4 lines

@@ -531,10 +531,6 @@ def train(self) -> None:

         tf.summary.experimental.set_step(self.global_step)

-        if self.args.fp16:
-            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
-            tf.keras.mixed_precision.experimental.set_policy(policy)
-
         with self.tb_writer.as_default():
             tf.summary.text("args", self.args.to_json_string())
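With this removal, `train()` no longer forces a `mixed_float16` policy unconditionally; the policy is instead chosen during strategy setup in `training_args_tf.py` (below), where the hardware is known. A minimal sketch of that choice, assuming the same experimental Keras mixed-precision API the diff uses; the helper name is hypothetical and not part of the commit:

```python
import tensorflow as tf

def set_mixed_precision_policy(on_tpu: bool) -> None:
    # Hypothetical helper mirroring the hardware-aware choice in the diff:
    # float16 on GPU, bfloat16 on TPU. Uses the experimental mixed-precision
    # API targeted by this commit (TF 2.2/2.3 era).
    name = "mixed_bfloat16" if on_tpu else "mixed_float16"
    policy = tf.keras.mixed_precision.experimental.Policy(name)
    tf.keras.mixed_precision.experimental.set_policy(policy)
```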

src/transformers/training_args_tf.py

+20 -1 lines

@@ -88,7 +88,7 @@ class TFTrainingArguments(TrainingArguments):
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the mumber of TPU cores (automatically passed by launcher script).
         debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Wheter to activate the trace to record computation graphs and profiling information or not.
+            Whether to activate the trace to record computation graphs and profiling information or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
@@ -103,19 +103,32 @@ class TFTrainingArguments(TrainingArguments):
             The name of the TPU the process is running on.
         run_name (:obj:`str`, `optional`):
             A descriptor for the run. Notably used for wandb logging.
+        xla (:obj:`bool`, `optional`):
+            Whether to activate the XLA compilation or not.
     """

     tpu_name: str = field(
         default=None,
         metadata={"help": "Name of TPU"},
     )

+    xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
+
     @cached_property
     @tf_required
     def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
         logger.info("Tensorflow: setting up strategy")
+
+        if self.args.xla:
+            tf.config.optimizer.set_jit(True)
+
         gpus = tf.config.list_physical_devices("GPU")

+        # Set to float16 at first
+        if self.fp16:
+            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
+            tf.keras.mixed_precision.experimental.set_policy(policy)
+
         if self.no_cuda:
             strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
         else:
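The new `xla` field drives a single call to `tf.config.optimizer.set_jit(True)`, which enables XLA auto-clustering globally so eligible graph ops get compiled by XLA. A minimal sketch of the effect, assuming a TF 2.x runtime; `scaled_add` is an illustrative function, not part of the commit:

```python
import tensorflow as tf

# Enable XLA auto-clustering globally, as the new `xla` flag does during
# strategy setup.
tf.config.optimizer.set_jit(True)

@tf.function
def scaled_add(x, y):
    # Ops inside tf.function graphs become eligible for XLA compilation/fusion.
    return 2.0 * x + y

print(scaled_add(tf.ones([2, 2]), tf.ones([2, 2])))
```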
@@ -128,10 +141,16 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
                 tpu = None

             if tpu:
+                # Set to bfloat16 in case of TPU
+                if self.fp16:
+                    policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
+                    tf.keras.mixed_precision.experimental.set_policy(policy)
+
                 tf.config.experimental_connect_to_cluster(tpu)
                 tf.tpu.experimental.initialize_tpu_system(tpu)

                 strategy = tf.distribute.experimental.TPUStrategy(tpu)
+
             elif len(gpus) == 0:
                 strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
             elif len(gpus) == 1:
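Taken together, the new argument can be combined with `fp16`, which now resolves to `mixed_float16` on GPU and `mixed_bfloat16` on TPU. A minimal usage sketch, assuming the `TFTrainingArguments` constructor of this release; the `output_dir` value is illustrative:

```python
from transformers import TFTrainingArguments

training_args = TFTrainingArguments(
    output_dir="./out",  # illustrative path
    xla=True,            # turn on XLA compilation during strategy setup
    fp16=True,           # mixed precision: float16 on GPU, bfloat16 on TPU
)
```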
