Skip to content

Commit a86d51d

Browse files
committed
use legacy keras
1 parent 1989452 commit a86d51d

File tree

6 files changed

+13
-17
lines changed

6 files changed

+13
-17
lines changed

.github/workflows/pypi_release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
66 66
pip install "numpy<1.24" "click<8.3.0"
67 67
pip install "pydantic<2.0"
68 68
pip install torch --index-url https://download.pytorch.org/whl/cpu
69-
pip install pyarrow "ray[train,default]==${{ env.RAY_VERSION }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget
69+
pip install pyarrow "ray[train,default]==${{ env.RAY_VERSION }}" tqdm pytest tensorflow==2.16.1 tf_keras tabulate grpcio-tools wget
70 70
pip install "xgboost_ray[default]<=0.1.13"
71 71
pip install "xgboost<=2.0.3"
72 72
pip install torchmetrics

.github/workflows/ray_nightly_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
pip install "ray[train,default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp311-cp311-manylinux2014_x86_64.whl"
9191
;;
9292
esac
93-
pip install pyarrow tqdm pytest "tensorflow>=2.16.1,<2.19" tabulate grpcio-tools wget
93+
pip install pyarrow tqdm pytest tabulate grpcio-tools wget
9494
pip install "xgboost_ray[default]<=0.1.13"
9595
pip install torchmetrics
9696
HOROVOD_WITH_GLOO=1
@@ -107,7 +107,7 @@ jobs:
107107
run: |
108108
pip install pyspark==${{ matrix.spark-version }}
109109
./build.sh
110-
pip install dist/raydp-*.whl
110+
pip install "$(ls dist/raydp-*.whl)[tensorflow]"
111111
- name: Lint
112112
run: |
113113
pip install pylint==3.2.7

.github/workflows/raydp.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
else
8383
pip install torch
8484
fi
85-
pip install pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest "tensorflow>=2.16.1,<2.19" tabulate grpcio-tools wget
85+
pip install pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest tabulate grpcio-tools wget
8686
pip install "xgboost_ray[default]<=0.1.13"
8787
pip install "xgboost<=2.0.3"
8888
pip install torchmetrics
@@ -97,7 +97,7 @@ jobs:
9797
run: |
9898
pip install pyspark==${{ matrix.spark-version }}
9999
./build.sh
100-
pip install dist/raydp-*.whl
100+
pip install "$(ls dist/raydp-*.whl)[tensorflow]"
101101
- name: Lint
102102
run: |
103103
pip install pylint==3.2.7

examples/tensorflow_titanic.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"source": [
1616
"import ray\n",
1717
"import os\n",
18+
"os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\"\n",
1819
"import re\n",
1920
"import pandas as pd, numpy as np\n",
2021
"\n",

python/raydp/tf/estimator.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from raydp.spark.interfaces import SparkEstimatorInterface, DF, OPTIONAL_DF
3939
from raydp import stop_spark
4040

41+
4142
class TFEstimator(EstimatorInterface, SparkEstimatorInterface):
4243
def __init__(self,
4344
num_workers: int = 1,
@@ -175,30 +176,21 @@ def train_func(config):
175176
# Model building/compiling need to be within `strategy.scope()`.
176177
multi_worker_model = TFEstimator.build_and_compile_model(config)
177178

178-
# Disable auto-sharding since Ray already handles data distribution
179-
# across workers. Without this, MultiWorkerMirroredStrategy tries to
180-
# re-shard the dataset, producing PerReplica objects that Keras 3.x
181-
# cannot convert back to tensors.
182-
ds_options = tf.data.Options()
183-
ds_options.experimental_distribute.auto_shard_policy = (
184-
tf.data.experimental.AutoShardPolicy.OFF
185-
)
186-
187179
train_dataset = session.get_dataset_shard("train")
188180
train_tf_dataset = train_dataset.to_tf(
189181
feature_columns=config["feature_columns"],
190182
label_columns=config["label_columns"],
191183
batch_size=config["batch_size"],
192184
drop_last=config["drop_last"]
193-
).with_options(ds_options)
185+
)
194186
if config["evaluate"]:
195187
eval_dataset = session.get_dataset_shard("evaluate")
196188
eval_tf_dataset = eval_dataset.to_tf(
197189
feature_columns=config["feature_columns"],
198190
label_columns=config["label_columns"],
199191
batch_size=config["batch_size"],
200192
drop_last=config["drop_last"]
201-
).with_options(ds_options)
193+
)
202194
results = []
203195
callbacks = config["callbacks"]
204196
for _ in range(config["num_epochs"]):

python/setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def run(self):
101101
"pyarrow >= 4.0.1",
102102
"ray >= 2.37.0",
103103
"pyspark >= 4.0.0",
104-
"netifaces",
105104
"protobuf > 3.19.5"
106105
]
107106

@@ -132,6 +131,10 @@ def run(self):
132131
'build_proto_modules': CustomBuildPackageProtos,
133132
},
134133
install_requires=install_requires,
134+
extras_require={
135+
"tensorflow": ["tensorflow>=2.15.1,<2.16"],
136+
"tensorflow-gpu": ["tensorflow[and-cuda]>=2.15.1,<2.16"],
137+
},
135138
setup_requires=["grpcio-tools"],
136139
python_requires='>=3.10',
137140
classifiers=[

0 commit comments

Comments (0)