def generate_spiral_data(n_points=100, noise=0.05):
    """Generate a noisy, inward-decaying 2-D spiral trajectory.

    The spiral is parameterized over one full turn (t in [0, 2*pi]) with an
    exponential radial decay exp(-0.1 * t), plus i.i.d. Gaussian noise on
    each coordinate.

    Args:
        n_points: Number of samples along the spiral.
        noise: Scale of the Gaussian noise added to each coordinate.

    Returns:
        Tuple of (t, trajectory) where t are time points and trajectory is the path
    """
    t = np.linspace(0.0, 2.0 * np.pi, n_points)
    decay = np.exp(-0.1 * t)
    # Draw the x-noise first, then the y-noise, to keep the RNG call order
    # identical for reproducibility under a fixed seed.
    xs = np.cos(t) * decay + noise * np.random.randn(n_points)
    ys = np.sin(t) * decay + noise * np.random.randn(n_points)
    return t, np.stack([xs, ys], axis=1)
@@ -38,67 +38,68 @@ def generate_spiral_data(n_points=100, noise=0.05):
class SimpleNeuralODE:
    """
    A simplified Neural ODE implementation using Euler's method.

    This demonstrates the core concept: learning the derivative function
    that governs the dynamics of a system.

    The forward pass integrates dx/dt = ode_func(x) with fixed-step Euler
    updates, and wires a manual backward hook onto each produced state so
    gradients can flow back through every step.
    """

    def __init__(self, ode_func, dt=0.01):
        """
        Initialize the Neural ODE.

        Args:
            ode_func: Neural network that learns dx/dt = f(x, t)
            dt: Time step for numerical integration
        """
        # ode_func is a project-defined callable returning a Tensor-like
        # object exposing .data and .grad — TODO confirm against Tensor's API.
        self.ode_func = ode_func
        self.dt = dt

    def __call__(self, x0, n_steps):
        """
        Integrate the ODE forward in time.

        Args:
            x0: Initial condition (batch_size, state_dim)
            n_steps: Number of integration steps

        Returns:
            Trajectory of states
        """
        # The returned list has n_steps + 1 entries: x0 plus one state per step.
        trajectory = [x0]
        x = x0

        for _ in range(n_steps):
            # Compute derivative: dx/dt = f(x)
            dx_dt = self.ode_func(x)

            # Euler step: x_{t+1} = x_t + dt * dx/dt
            # Computed on raw .data arrays, then wrapped in a fresh Tensor.
            x_next_data = x.data + self.dt * dx_dt.data
            x_next = Tensor(x_next_data, requires_grad=x.requires_grad)

            # Set up backward pass
            if x.requires_grad:

                # Default arguments bind x / dx_dt / x_next eagerly, avoiding
                # Python's late-binding-closure pitfall across loop iterations.
                def _backward(x_curr=x, dx_dt_curr=dx_dt, x_next_curr=x_next):
                    if x_next_curr.grad is not None:
                        # Gradient flows back through Euler step
                        # d(x_next)/d(x) = I, so the upstream grad accumulates
                        # into x_curr unchanged.
                        if x_curr.grad is None:
                            x_curr.grad = x_next_curr.grad.copy()
                        else:
                            x_curr.grad += x_next_curr.grad

                        # Gradient w.r.t. derivative
                        # d(x_next)/d(dx_dt) = dt, hence the dt scaling.
                        dx_dt_grad = self.dt * x_next_curr.grad
                        if dx_dt_curr.grad is None:
                            dx_dt_curr.grad = dx_dt_grad
                        else:
                            dx_dt_curr.grad += dx_dt_grad

                # NOTE(review): assumes Tensor's autograd walks ._prev and
                # calls ._backward (micrograd-style), and that Tensor is
                # hashable so it can live in a set — confirm against Tensor.
                x_next._backward = _backward
                x_next._prev = {x, dx_dt}

            trajectory.append(x_next)
            x = x_next

        return trajectory
103104
104105
def run_neural_ode_example():
    """Train a simple Neural ODE to learn spiral dynamics end to end.

    Generates a noisy spiral, fits a small MLP as the derivative function
    f(x) via one-step Euler predictions between consecutive states, then
    rolls the learned dynamics forward from the initial condition and
    plots the results.

    Returns:
        True on completion (used by the __main__ guard to report success).
    """
    print("Neural ODE Example: Learning Spiral Dynamics")
    print("=" * 50)

    # Generate spiral data
    t, true_trajectory = generate_spiral_data(n_points=50, noise=0.02)

    print(f"Generated spiral data with {len(true_trajectory)} points")

    # Create Neural ODE function: this network learns dx/dt = f(x).
    ode_func = Sequential(Linear(2, 16), Tanh(), Linear(16, 16), Tanh(), Linear(16, 2))

    # Create Neural ODE solver
    neural_ode = SimpleNeuralODE(ode_func, dt=0.05)

    # Prepare training data from consecutive points: x[i] -> x[i+1]
    X_data = true_trajectory[:-1]  # Current states
    y_data = true_trajectory[1:]  # Next states

    # Targets never need gradients; inputs are wrapped per-sample below.
    y_tensor = Tensor(y_data, requires_grad=False)

    # Loss function and optimizer
    loss_fn = MSELoss()
    optimizer = Adam(ode_func.parameters(), lr=0.001)

    # Training loop
    epochs = 500
    losses = []

    print("\nTraining Neural ODE...")

    for epoch in range(epochs):
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass: predict the next state for every current state.
        predicted_trajectory = []
        for i in range(len(X_data)):
            x_current = Tensor(X_data[i : i + 1], requires_grad=True)
            trajectory = neural_ode(x_current, n_steps=1)
            predicted_trajectory.append(trajectory[1])  # Next state

        # Stack predictions into a single tensor.
        # NOTE(review): rebuilding a Tensor from raw .data detaches these
        # predictions from the autograd graph, so loss.backward() may never
        # reach ode_func's parameters — verify whether Tensor offers a
        # graph-preserving stack/concat op and use it here if so.
        predicted_next = Tensor(
            np.array([pred.data[0] for pred in predicted_trajectory]),
            requires_grad=True,
        )

        # Compute loss
        loss = loss_fn(predicted_next, y_tensor)

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()

        losses.append(loss.data)

        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Loss = {loss.data:.6f}")

    print("Training completed!")

    # Generate predictions for a full-length trajectory.
    print("\nGenerating predictions...")

    # Start from the true initial condition and roll the dynamics forward.
    x0 = Tensor(true_trajectory[0:1], requires_grad=False)
    predicted_trajectory = neural_ode(x0, n_steps=len(true_trajectory) - 1)

    # Extract raw arrays for plotting.
    predicted_data = np.array([state.data[0] for state in predicted_trajectory])

    # Plot results
    plot_results(true_trajectory, predicted_data, losses)

    return True
195190
196191
def plot_results(true_trajectory, predicted_trajectory, losses):
    """Plot true vs. predicted trajectories, the loss curve, and errors.

    Produces a three-panel figure (trajectory comparison, log-scale
    training loss, per-step L2 error), saves it as
    'neural_ode_results.png', and prints summary statistics.

    Args:
        true_trajectory: (n, 2) array of ground-truth states.
        predicted_trajectory: (n, 2) array of predicted states.
        losses: Sequence of per-epoch training loss values.

    Any plotting failure (matplotlib unavailable, headless display, ...)
    is caught and reported rather than raised, so callers never crash here.
    """
    try:
        plt.figure(figsize=(15, 5))

        # Plot trajectories
        plt.subplot(1, 3, 1)
        plt.plot(
            true_trajectory[:, 0],
            true_trajectory[:, 1],
            "b-",
            label="True",
            linewidth=2,
        )
        plt.plot(
            predicted_trajectory[:, 0],
            predicted_trajectory[:, 1],
            "r--",
            label="Predicted",
            linewidth=2,
        )
        plt.scatter(
            true_trajectory[0, 0],
            true_trajectory[0, 1],
            c="green",
            s=100,
            label="Start",
        )
        plt.title("Trajectory Comparison")
        plt.xlabel("X0")
        plt.ylabel("X1")
        plt.legend()
        plt.grid(True)
        # Must be exactly "equal": plt.axis(" equal") raises ValueError.
        plt.axis("equal")

        # Plot loss curve
        plt.subplot(1, 3, 2)
        plt.plot(losses)
        plt.title("Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.yscale("log")
        plt.grid(True)

        # Plot error over time
        plt.subplot(1, 3, 3)
        errors = np.linalg.norm(true_trajectory - predicted_trajectory, axis=1)
        plt.plot(errors)
        plt.title("Prediction Error")
        plt.xlabel("Time Step")
        plt.ylabel("L2 Error")
        plt.grid(True)

        plt.tight_layout()
        # Filename must not carry a leading space, or the file is saved
        # under ' neural_ode_results.png' and the message below is wrong.
        plt.savefig("neural_ode_results.png", dpi=150, bbox_inches="tight")
        plt.show()

        print("Results saved as 'neural_ode_results.png'")

        # Print final statistics
        final_error = np.mean(errors)
        print(f"\nFinal average prediction error: {final_error:.4f}")

    except Exception as e:
        print(f"Could not generate plot: {e}")
247260
248261if __name__ == "__main__" :
249262 print ("Simplified Neural ODE Example" )
250263 print ("=" * 30 )
251-
264+
252265 try :
253266 success = run_neural_ode_example ()
254267 if success :
@@ -258,4 +271,5 @@ def plot_results(true_trajectory, predicted_trajectory, losses):
258271 except Exception as e :
259272 print (f"\n ❌ Error running Neural ODE example: { e } " )
260273 import traceback
261- traceback .print_exc ()
274+
275+ traceback .print_exc ()
0 commit comments