Skip to content

Commit 8583117

Browse files
committed
add docstring for compile [test]
1 parent d102264 commit 8583117

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

tensordiffeq/models.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,21 @@ def __init__(self, assimilate=False):
1313

1414
def compile(self, layer_sizes, f_model, domain, bcs, isAdaptive=False,
1515
col_weights=None, u_weights=None, g=None, dist=False):
16+
"""
17+
Args:
18+
layer_sizes: A list of layer sizes; can be overridden by resetting u_model to a Keras model
19+
f_model: PDE definition
20+
domain: a Domain object containing the information on the domain of the system
21+
bcs: a list of ICs/BCs for the problem
22+
isAdaptive: Boolean value determining whether to implement self-adaptive solving
23+
col_weights: a tf.Variable vector of collocation point weights for self-adaptive solving
24+
u_weights: a tf.Variable vector of initial boundary weights for self-adaptive training
25+
g: a function in terms of `lambda` for self-adaptive solving. Defaults to lambda^2
26+
dist: A boolean value determining whether the solving will be distributed across multiple GPUs
27+
28+
Returns:
29+
None
30+
"""
1631
self.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
1732
self.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
1833
self.layer_sizes = layer_sizes
@@ -127,13 +142,13 @@ def fit(self, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
127142
BATCH_SIZE_PER_REPLICA = self.batch_sz
128143
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * self.strategy.num_replicas_in_sync
129144

130-
options = tf.data.Options()
131-
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
145+
# options = tf.data.Options()
146+
# options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
132147

133-
self.train_dataset = tf.data.Dataset.from_tensors(
134-
self.X_f_in) # .batch(GLOBAL_BATCH_SIZE)
148+
self.train_dataset = tf.data.Dataset.from_tensor_slices(
149+
self.X_f_in).batch(GLOBAL_BATCH_SIZE)
135150

136-
self.train_dataset = self.train_dataset.with_options(options)
151+
# self.train_dataset = self.train_dataset.with_options(options)
137152

138153
self.train_dist_dataset = self.strategy.experimental_distribute_dataset(self.train_dataset)
139154

@@ -183,8 +198,8 @@ def predict(self, X_star):
183198
def save(self, path):
184199
self.u_model.save(path)
185200

186-
def load_model(self, path):
187-
self.u_model = tf.keras.models.load_model(path)
201+
def load_model(self, path, compile_model=False):
202+
self.u_model = tf.keras.models.load_model(path, compile=compile_model)
188203

189204
# WIP
190205
# TODO Distributed Discovery Model

0 commit comments

Comments
 (0)