@@ -11,22 +11,17 @@ def func(x):
1111 return x * np .sin (5 * x )
1212
1313
14- def main ():
15- geom = dde .geometry .Interval (- 1 , 1 )
16- num_train = 16
17- num_test = 100
18- data = dde .data .Function (geom , func , num_train , num_test )
14+ geom = dde .geometry .Interval (- 1 , 1 )
15+ num_train = 16
16+ num_test = 100
17+ data = dde .data .Function (geom , func , num_train , num_test )
1918
20- activation = "tanh"
21- initializer = "Glorot uniform"
22- net = dde .maps .FNN ([1 ] + [20 ] * 3 + [1 ], activation , initializer )
19+ activation = "tanh"
20+ initializer = "Glorot uniform"
21+ net = dde .maps .FNN ([1 ] + [20 ] * 3 + [1 ], activation , initializer )
2322
24- model = dde .Model (data , net )
25- model .compile ("adam" , lr = 0.001 , metrics = ["l2 relative error" ])
26- losshistory , train_state = model .train (epochs = 10000 )
23+ model = dde .Model (data , net )
24+ model .compile ("adam" , lr = 0.001 , metrics = ["l2 relative error" ])
25+ losshistory , train_state = model .train (epochs = 10000 )
2726
28- dde .saveplot (losshistory , train_state , issave = True , isplot = True )
29-
30-
31- if __name__ == "__main__" :
32- main ()
27+ dde .saveplot (losshistory , train_state , issave = True , isplot = True )
0 commit comments