Skip to content

Commit 096b848

Browse files
authored
Allow bfloat16 default dtype (#19074)
Useful for llms! The tradeoff in precision can often be worth it in memory constrained environments, and unlike float16, does not have the same overflow/underflow issues during training.
1 parent 2feb430 commit 096b848

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

keras/backend/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
def floatx():
2121
"""Return the default float type, as a string.
2222
23-
E.g. `'float16'`, `'float32'`, `'float64'`.
23+
E.g. `'bfloat16'`, `'float16'`, `'float32'`, `'float64'`.
2424
2525
Returns:
2626
String, the current default float type.
@@ -45,7 +45,7 @@ def set_floatx(value):
4545
`keras.mixed_precision.set_dtype_policy('mixed_float16')`.
4646
4747
Args:
48-
value: String; `'float16'`, `'float32'`, or `'float64'`.
48+
value: String; `'bfloat16'`, `'float16'`, `'float32'`, or `'float64'`.
4949
5050
Examples:
5151
>>> keras.config.floatx()
@@ -62,7 +62,7 @@ def set_floatx(value):
6262
ValueError: In case of invalid value.
6363
"""
6464
global _FLOATX
65-
accepted_dtypes = {"float16", "float32", "float64"}
65+
accepted_dtypes = {"bfloat16", "float16", "float32", "float64"}
6666
if value not in accepted_dtypes:
6767
raise ValueError(
6868
f"Unknown `floatx` value: {value}. "

0 commit comments

Comments
 (0)