@@ -13,6 +13,21 @@ def __init__(self, assimilate=False):
1313
1414 def compile (self , layer_sizes , f_model , domain , bcs , isAdaptive = False ,
1515 col_weights = None , u_weights = None , g = None , dist = False ):
16+ """
17+ Args:
18+ layer_sizes: A list of layer sizes, can be overwritten via resetting u_model to a keras model
19+ f_model: PDE definition
20+ domain: a Domain object containing the information on the domain of the system
21+ bcs: a list of ICs/BCs for the problem
22+ isAdaptive: Boolean value determining whether to implement self-adaptive solving
23+ col_weights: a tf.Variable vector of collocation point weights for self-adaptive solving
24+ u_weights: a tf.Variable vector of initial boundary weights for self-adaptive training
25+ g: a function in terms of `lambda` for self-adapting solving. Defaults to lambda^2
26+ dist: A boolean value determining whether the solving will be distributed across multiple GPUs
27+
28+ Returns:
29+ None
30+ """
1631 self .tf_optimizer = tf .keras .optimizers .Adam (lr = 0.005 , beta_1 = .99 )
1732 self .tf_optimizer_weights = tf .keras .optimizers .Adam (lr = 0.005 , beta_1 = .99 )
1833 self .layer_sizes = layer_sizes
@@ -127,13 +142,13 @@ def fit(self, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
127142 BATCH_SIZE_PER_REPLICA = self .batch_sz
128143 GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * self .strategy .num_replicas_in_sync
129144
130- options = tf .data .Options ()
131- options .experimental_distribute .auto_shard_policy = tf .data .experimental .AutoShardPolicy .DATA
145+ # options = tf.data.Options()
146+ # options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
132147
133- self .train_dataset = tf .data .Dataset .from_tensors (
134- self .X_f_in ) # .batch(GLOBAL_BATCH_SIZE)
148+ self .train_dataset = tf .data .Dataset .from_tensor_slices (
149+ self .X_f_in ).batch (GLOBAL_BATCH_SIZE )
135150
136- self .train_dataset = self .train_dataset .with_options (options )
151+ # self.train_dataset = self.train_dataset.with_options(options)
137152
138153 self .train_dist_dataset = self .strategy .experimental_distribute_dataset (self .train_dataset )
139154
@@ -183,8 +198,8 @@ def predict(self, X_star):
183198 def save (self , path ):
184199 self .u_model .save (path )
185200
186- def load_model (self , path ):
187- self .u_model = tf .keras .models .load_model (path )
201+ def load_model (self , path , compile_model = False ):
202+ self .u_model = tf .keras .models .load_model (path , compile = compile_model )
188203
189204# WIP
190205# TODO Distributed Discovery Model
0 commit comments