|
38 | 38 | from raydp.spark.interfaces import SparkEstimatorInterface, DF, OPTIONAL_DF |
39 | 39 | from raydp import stop_spark |
40 | 40 |
|
| 41 | + |
41 | 42 | class TFEstimator(EstimatorInterface, SparkEstimatorInterface): |
42 | 43 | def __init__(self, |
43 | 44 | num_workers: int = 1, |
@@ -175,30 +176,21 @@ def train_func(config): |
175 | 176 | # Model building/compiling need to be within `strategy.scope()`. |
176 | 177 | multi_worker_model = TFEstimator.build_and_compile_model(config) |
177 | 178 |
|
178 | | - # Disable auto-sharding since Ray already handles data distribution |
179 | | - # across workers. Without this, MultiWorkerMirroredStrategy tries to |
180 | | - # re-shard the dataset, producing PerReplica objects that Keras 3.x |
181 | | - # cannot convert back to tensors. |
182 | | - ds_options = tf.data.Options() |
183 | | - ds_options.experimental_distribute.auto_shard_policy = ( |
184 | | - tf.data.experimental.AutoShardPolicy.OFF |
185 | | - ) |
186 | | - |
187 | 179 | train_dataset = session.get_dataset_shard("train") |
188 | 180 | train_tf_dataset = train_dataset.to_tf( |
189 | 181 | feature_columns=config["feature_columns"], |
190 | 182 | label_columns=config["label_columns"], |
191 | 183 | batch_size=config["batch_size"], |
192 | 184 | drop_last=config["drop_last"] |
193 | | - ).with_options(ds_options) |
| 185 | + ) |
194 | 186 | if config["evaluate"]: |
195 | 187 | eval_dataset = session.get_dataset_shard("evaluate") |
196 | 188 | eval_tf_dataset = eval_dataset.to_tf( |
197 | 189 | feature_columns=config["feature_columns"], |
198 | 190 | label_columns=config["label_columns"], |
199 | 191 | batch_size=config["batch_size"], |
200 | 192 | drop_last=config["drop_last"] |
201 | | - ).with_options(ds_options) |
| 193 | + ) |
202 | 194 | results = [] |
203 | 195 | callbacks = config["callbacks"] |
204 | 196 | for _ in range(config["num_epochs"]): |
|
0 commit comments