Skip to content

Commit ef33a69

Browse files
authored
Merge branch 'keras-team:master' into master
2 parents 7962b03 + a9fc80d commit ef33a69

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+2123
-532
lines changed

.github/workflows/scorecard.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ jobs:
4848
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
4949
# format to the repository Actions tab.
5050
- name: "Upload artifact"
51-
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
51+
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
5252
with:
5353
name: SARIF file
5454
path: results.sarif
5555
retention-days: 5
5656

5757
# Upload the results to GitHub's code scanning dashboard.
5858
- name: "Upload to code-scanning"
59-
uses: github/codeql-action/upload-sarif@b20883b0cd1f46c72ae0ba6d1090936928f9fa30 # v3.29.5
59+
uses: github/codeql-action/upload-sarif@89a39a4e59826350b863aa6b6252a07ad50cf83e # v3.29.5
6060
with:
6161
sarif_file: results.sarif

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@
251251
from keras.src.ops.numpy import moveaxis as moveaxis
252252
from keras.src.ops.numpy import multiply as multiply
253253
from keras.src.ops.numpy import nan_to_num as nan_to_num
254+
from keras.src.ops.numpy import nanargmin as nanargmin
254255
from keras.src.ops.numpy import nancumsum as nancumsum
255256
from keras.src.ops.numpy import nanmax as nanmax
256257
from keras.src.ops.numpy import nanmean as nanmean

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
from keras.src.ops.numpy import moveaxis as moveaxis
134134
from keras.src.ops.numpy import multiply as multiply
135135
from keras.src.ops.numpy import nan_to_num as nan_to_num
136+
from keras.src.ops.numpy import nanargmin as nanargmin
136137
from keras.src.ops.numpy import nancumsum as nancumsum
137138
from keras.src.ops.numpy import nanmax as nanmax
138139
from keras.src.ops.numpy import nanmean as nanmean

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@
251251
from keras.src.ops.numpy import moveaxis as moveaxis
252252
from keras.src.ops.numpy import multiply as multiply
253253
from keras.src.ops.numpy import nan_to_num as nan_to_num
254+
from keras.src.ops.numpy import nanargmin as nanargmin
254255
from keras.src.ops.numpy import nancumsum as nancumsum
255256
from keras.src.ops.numpy import nanmax as nanmax
256257
from keras.src.ops.numpy import nanmean as nanmean

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
from keras.src.ops.numpy import moveaxis as moveaxis
134134
from keras.src.ops.numpy import multiply as multiply
135135
from keras.src.ops.numpy import nan_to_num as nan_to_num
136+
from keras.src.ops.numpy import nanargmin as nanargmin
136137
from keras.src.ops.numpy import nancumsum as nancumsum
137138
from keras.src.ops.numpy import nanmax as nanmax
138139
from keras.src.ops.numpy import nanmean as nanmean

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,11 @@ def moveaxis(x, source, destination):
10431043
return jnp.moveaxis(x, source=source, destination=destination)
10441044

10451045

1046+
def nanargmin(x, axis=None, keepdims=False):
1047+
x = convert_to_tensor(x)
1048+
return jnp.nanargmin(x, axis=axis, keepdims=keepdims)
1049+
1050+
10461051
def nancumsum(x, axis=None, dtype=None):
10471052
x = convert_to_tensor(x)
10481053
return jnp.nancumsum(x, axis=axis, dtype=dtype)

keras/src/backend/jax/random.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def _get_concrete_noise_shape(inputs, noise_shape):
6969

7070

7171
def dropout(inputs, rate, noise_shape=None, seed=None):
72+
if rate == 1.0:
73+
return jax.numpy.zeros_like(inputs)
74+
if rate == 0.0:
75+
return inputs
7276
seed = jax_draw_seed(seed)
7377
keep_prob = 1.0 - rate
7478
# The `noise_shape` may contain `None` so we need to convert it

keras/src/backend/jax/trainer.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,13 +916,95 @@ def _get_state_sharding_spec(self):
916916
else:
917917
optimizer_shardings = []
918918
metrics_shardings = [v.value.sharding for v in self.metrics_variables]
919+
920+
self._check_sharding_consistency(
921+
trainable_shardings,
922+
non_trainable_shardings,
923+
optimizer_shardings,
924+
metrics_shardings,
925+
)
926+
919927
return (
920928
trainable_shardings,
921929
non_trainable_shardings,
922930
optimizer_shardings,
923931
metrics_shardings,
924932
)
925933

934+
def _check_sharding_consistency(
935+
self,
936+
trainable_shardings,
937+
non_trainable_shardings,
938+
optimizer_shardings,
939+
metrics_shardings,
940+
):
941+
"""Warn if there is a mix of local and distributed variable shardings.
942+
943+
When some variables have SingleDeviceSharding (created outside the
944+
distribution scope) and others have mesh-aware shardings (created
945+
inside), passing them together as `out_shardings` to `jax.jit`
946+
raises ``ValueError: Received incompatible devices for jitted
947+
computation``. This helper detects the mismatch early and emits
948+
an actionable warning.
949+
"""
950+
if distribution_lib.distribution() is None:
951+
return
952+
953+
var_shard_pairs = itertools.chain(
954+
zip(self.trainable_variables, trainable_shardings),
955+
zip(self.non_trainable_variables, non_trainable_shardings),
956+
zip(
957+
(
958+
self.optimizer.variables
959+
if hasattr(self, "optimizer") and self.optimizer
960+
else []
961+
),
962+
optimizer_shardings,
963+
),
964+
zip(self.metrics_variables, metrics_shardings),
965+
)
966+
967+
first_local_var_path = None
968+
has_mesh = False
969+
for v, s in var_shard_pairs:
970+
if isinstance(s, jax.sharding.SingleDeviceSharding):
971+
if first_local_var_path is None:
972+
first_local_var_path = v.path
973+
else:
974+
has_mesh = True
975+
# Early exit: we know there is a mix as soon as we have
976+
# seen at least one of each kind.
977+
if first_local_var_path and has_mesh:
978+
break
979+
980+
if not (first_local_var_path and has_mesh):
981+
return
982+
983+
warnings.warn(
984+
"Detected a mix of local (SingleDeviceSharding) and "
985+
"distributed (mesh-aware) variables. This will cause "
986+
"a 'ValueError: Received incompatible devices for "
987+
"jitted computation' when JAX tries to compile the "
988+
"training step.\n\n"
989+
f"First local variable found: {first_local_var_path!r}\n\n"
990+
"This typically happens when the model is built or "
991+
"weights are loaded before the distribution is set. "
992+
"To fix this, call set_distribution() before creating "
993+
"any Keras objects:\n\n"
994+
" import keras\n"
995+
" keras.distribution.set_distribution(distribution)\n"
996+
" model = create_model()\n"
997+
" model.compile(...)\n"
998+
" model.fit(...)\n\n"
999+
"Alternatively, use the distribution scope context "
1000+
"manager:\n\n"
1001+
" with distribution.scope():\n"
1002+
" model = create_model()\n"
1003+
" model.compile(...)\n"
1004+
" model.fit(...)\n",
1005+
stacklevel=3,
1006+
)
1007+
9261008
def _purge_model_variables(
9271009
self,
9281010
trainable_variables=False,
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import warnings
2+
3+
import numpy as np
4+
from absl.testing import parameterized
5+
6+
from keras.src import backend
7+
from keras.src import layers
8+
from keras.src import models
9+
from keras.src import testing
10+
from keras.src.backend import distribution_lib as backend_dlib
11+
from keras.src.distribution import distribution_lib
12+
13+
14+
class JAXTrainerTest(testing.TestCase, parameterized.TestCase):
15+
def _skip_if_not_distributed(self):
16+
if backend.backend() != "jax":
17+
self.skipTest("Requires JAX backend")
18+
if len(backend_dlib.list_devices()) < 2:
19+
self.skipTest("Requires at least 2 devices")
20+
21+
def _make_distribution(self, dist_type):
22+
if dist_type == "data_parallel":
23+
return distribution_lib.DataParallel()
24+
devices = backend_dlib.list_devices()
25+
n = len(devices)
26+
mesh = distribution_lib.DeviceMesh((n,), ["model"], devices)
27+
layout_map = distribution_lib.LayoutMap(mesh)
28+
layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
29+
[None, "model"]
30+
)
31+
layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])
32+
return distribution_lib.ModelParallel(layout_map=layout_map)
33+
34+
# ----------------------------------------------------------------
35+
# Mixed-sharding warning tests
36+
# ----------------------------------------------------------------
37+
@parameterized.named_parameters(
38+
{"testcase_name": "DataParallel", "dist_type": "data_parallel"},
39+
{"testcase_name": "ModelParallel", "dist_type": "model_parallel"},
40+
)
41+
def test_warns_when_model_built_outside_scope(self, dist_type):
42+
"""Model built outside distribution -> mixed warning on compile."""
43+
self._skip_if_not_distributed()
44+
import jax
45+
46+
n = len(backend_dlib.list_devices())
47+
units = n * max(1, 4 // n)
48+
dist = self._make_distribution(dist_type)
49+
50+
# Model created outside any distribution scope — weights are local.
51+
model = models.Sequential([layers.Dense(units, input_shape=(16,))])
52+
53+
for w in model.weights:
54+
self.assertIsInstance(
55+
w.value.sharding, jax.sharding.SingleDeviceSharding
56+
)
57+
58+
inputs = np.random.normal(size=(8, 16)).astype("float32")
59+
labels = np.random.normal(size=(8, units)).astype("float32")
60+
61+
with dist.scope():
62+
model.compile(loss="mse", optimizer="adam")
63+
with warnings.catch_warnings(record=True) as caught:
64+
warnings.simplefilter("always")
65+
model._symbolic_build(data_batch=(inputs[:2], labels[:2]))
66+
model._get_state_sharding_spec()
67+
68+
mixed = [w for w in caught if "mix of local" in str(w.message)]
69+
self.assertGreater(
70+
len(mixed),
71+
0,
72+
"Expected a mixed-sharding warning but none was raised",
73+
)
74+
msg = str(mixed[0].message)
75+
self.assertIn("SingleDeviceSharding", msg)
76+
self.assertIn("set_distribution", msg)
77+
78+
@parameterized.named_parameters(
79+
{"testcase_name": "DataParallel", "dist_type": "data_parallel"},
80+
{"testcase_name": "ModelParallel", "dist_type": "model_parallel"},
81+
)
82+
def test_no_warning_when_model_built_inside_scope(self, dist_type):
83+
"""Model built inside distribution scope -> no warning."""
84+
self._skip_if_not_distributed()
85+
86+
n = len(backend_dlib.list_devices())
87+
units = n * max(1, 4 // n)
88+
dist = self._make_distribution(dist_type)
89+
90+
# Model created inside scope — weights get proper sharding.
91+
with dist.scope():
92+
model = models.Sequential([layers.Dense(units, input_shape=(16,))])
93+
94+
inputs = np.random.normal(size=(8, 16)).astype("float32")
95+
labels = np.random.normal(size=(8, units)).astype("float32")
96+
97+
with dist.scope():
98+
model.compile(loss="mse", optimizer="adam")
99+
with warnings.catch_warnings(record=True) as caught:
100+
warnings.simplefilter("always")
101+
model._symbolic_build(data_batch=(inputs[:2], labels[:2]))
102+
model._get_state_sharding_spec()
103+
104+
mixed = [w for w in caught if "mix of local" in str(w.message)]
105+
self.assertEqual(
106+
len(mixed),
107+
0,
108+
"Unexpected mixed-sharding warning when model is "
109+
"built inside scope",
110+
)

keras/src/backend/numpy/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,8 @@ def scatter(indices, values, shape):
330330
indices = np.reshape(indices, [-1, index_length])
331331
values = np.reshape(values, [-1] + list(value_shape))
332332

333-
for i in range(indices.shape[0]):
334-
index = indices[i]
335-
zeros[tuple(index)] += values[i]
333+
idx = tuple(indices.T)
334+
np.add.at(zeros, idx, values)
336335
return zeros
337336

338337

0 commit comments

Comments
 (0)