Skip to content

Commit 9860d5c

Browse files
CopilotRouthleck
andcommitted
Improve JointEq error message and add tests for second-order ODEs
Co-authored-by: Routhleck <[email protected]>
1 parent 611a81e commit 9860d5c

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

brainpy/integrators/joint_eq.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,18 @@ def __init__(self, *eqs):
154154
vars, _, _ = _get_args(eq)
155155
for var in vars:
156156
if var in vars_in_eqs:
157-
raise DiffEqError(f'Variable "{var}" has been used, however we got a same '
158-
f'variable name in {eq}. Please change another name.')
157+
raise DiffEqError(
158+
f'Variable "{var}" has been used, however we got a same '
159+
f'variable name in {eq}.\n\n'
160+
f'In JointEq, each state variable should appear as the first parameter '
161+
f'before "t" in exactly one derivative function. If "{var}" is a state '
162+
f'variable in another equation, it should be placed AFTER "t" in this '
163+
f'function as a dependency.\n\n'
164+
f'Correct signature pattern:\n'
165+
f' def d{var}({var}, t, <dependencies>): ... # {var} is the state variable\n'
166+
f' def dOther(other, t, {var}): ... # {var} is a dependency\n\n'
167+
f'Current function signature: {inspect.signature(eq)}'
168+
)
159169
vars_in_eqs.extend(vars)
160170
self.vars_in_eqs.append(vars)
161171

brainpy/integrators/tests/test_joint_eq.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,46 @@ def test_nested_joint_eq1(self):
128128
EQ2 = JointEq(EQ1, dn)
129129
EQ3 = JointEq(EQ2, dV)
130130
print(EQ3(m=0.1, h=0.2, n=0.3, V=10., t=0., I=0.))
131+
132+
def test_second_order_ode(self):
133+
"""Test second-order ODE system (e.g., harmonic oscillator)"""
134+
# Second-order ODE: d²x/dt² = -k*x - c*dx/dt
135+
# Split into: dx/dt = v, dv/dt = -k*x - c*v
136+
k = 1.0 # spring constant
137+
c = 0.1 # damping
138+
139+
def dx(x, t, v):
140+
"""dx/dt = v"""
141+
return v
142+
143+
def dv(v, t, x):
144+
"""dv/dt = -k*x - c*v"""
145+
return -k * x - c * v
146+
147+
# Create joint equation
148+
eq = JointEq(dx, dv)
149+
150+
# Test call
151+
result = eq(x=1.0, v=0.0, t=0.0)
152+
self.assertEqual(len(result), 2)
153+
self.assertEqual(result[0], 0.0) # dx/dt = v = 0
154+
self.assertEqual(result[1], -k * 1.0) # dv/dt = -k*x
155+
156+
def test_second_order_ode_wrong_signature(self):
157+
"""Test that wrong signature gives helpful error message"""
158+
# WRONG: both x and v before t in dx function
159+
def dx_wrong(x, v, t):
160+
return v
161+
162+
def dv(v, t, x):
163+
return -x
164+
165+
# This should raise an error with helpful message
166+
with self.assertRaises(DiffEqError) as cm:
167+
JointEq(dx_wrong, dv)
168+
169+
# Check that error message is helpful
170+
error_msg = str(cm.exception)
171+
self.assertIn('state variable', error_msg.lower())
172+
self.assertIn('AFTER "t"', error_msg)
173+
self.assertIn('dependency', error_msg.lower())

0 commit comments

Comments
 (0)