Skip to content

Commit 4fa1899

Browse files
committed
[wpimath] Add Sleipnir Java bindings
1 parent 5c7aaa7 commit 4fa1899

File tree

82 files changed

+12174
-183
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+12174
-183
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright (c) FIRST and other WPILib contributors.
2+
// Open Source Software; you can modify and/or share it under the terms of
3+
// the WPILib BSD license file in the root directory of this project.
4+
5+
package wpilib.robot;
6+
7+
import static org.wpilib.math.autodiff.NumericalIntegration.rk4;
8+
import static org.wpilib.math.autodiff.Variable.cos;
9+
import static org.wpilib.math.autodiff.Variable.sin;
10+
import static org.wpilib.math.autodiff.VariableMatrix.solve;
11+
import static org.wpilib.math.optimization.Constraints.eq;
12+
import static org.wpilib.math.optimization.Constraints.ge;
13+
import static org.wpilib.math.optimization.Constraints.le;
14+
15+
import org.ejml.simple.SimpleMatrix;
16+
import org.wpilib.math.autodiff.Variable;
17+
import org.wpilib.math.autodiff.VariableMatrix;
18+
import org.wpilib.math.optimization.Problem;
19+
import org.wpilib.math.optimization.solver.Options;
20+
import org.wpilib.math.util.MathUtil;
21+
22+
public final class CartPoleBenchmark {
23+
private CartPoleBenchmark() {
24+
// Utility class.
25+
}
26+
27+
@SuppressWarnings("LocalVariableName")
28+
private static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) {
29+
final double m_c = 5.0; // Cart mass (kg)
30+
final double m_p = 0.5; // Pole mass (kg)
31+
final double l = 0.5; // Pole length (m)
32+
final double g = 9.806; // Acceleration due to gravity (m/s²)
33+
34+
var q = x.segment(0, 2);
35+
var qdot = x.segment(2, 2);
36+
var theta = q.get(1);
37+
var thetadot = qdot.get(1);
38+
39+
// [ m_c + m_p m_p l cosθ]
40+
// M(q) = [m_p l cosθ m_p l² ]
41+
var M =
42+
new VariableMatrix(
43+
new Variable[][] {
44+
{new Variable(m_c + m_p), cos(theta).times(m_p * l)},
45+
{cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))}
46+
});
47+
48+
// [0 −m_p lθ̇ sinθ]
49+
// C(q, q̇) = [0 0 ]
50+
var C =
51+
new VariableMatrix(
52+
new Variable[][] {
53+
{new Variable(0), thetadot.times(-m_p * l).times(sin(theta))},
54+
{new Variable(0), new Variable(0)}
55+
});
56+
57+
// [ 0 ]
58+
// τ_g(q) = [-m_p gl sinθ]
59+
var tau_g =
60+
new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}});
61+
62+
// [1]
63+
// B = [0]
64+
var B = new VariableMatrix(new double[][] {{1}, {0}});
65+
66+
// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu)
67+
var qddot = new VariableMatrix(4);
68+
qddot.segment(0, 2).set(qdot);
69+
qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u))));
70+
return qddot;
71+
}
72+
73+
/** Cart-pole benchmark. */
74+
public static void cartPole() {
75+
final double T = 5.0; // s
76+
final double dt = 0.05; // s
77+
final int N = (int) (T / dt);
78+
79+
final double u_max = 20.0; // N
80+
final double d_max = 2.0; // m
81+
82+
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}});
83+
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}});
84+
85+
var problem = new Problem();
86+
87+
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
88+
var X = problem.decisionVariable(4, N + 1);
89+
90+
// Initial guess
91+
for (int k = 0; k < N + 1; ++k) {
92+
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
93+
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
94+
}
95+
96+
// u = f_x
97+
var U = problem.decisionVariable(1, N);
98+
99+
// Initial conditions
100+
problem.subjectTo(eq(X.col(0), x_initial));
101+
102+
// Final conditions
103+
problem.subjectTo(eq(X.col(N), x_final));
104+
105+
// Cart position constraints
106+
problem.subjectTo(ge(X.row(0), 0.0));
107+
problem.subjectTo(le(X.row(0), d_max));
108+
109+
// Input constraints
110+
problem.subjectTo(ge(U, -u_max));
111+
problem.subjectTo(le(U, u_max));
112+
113+
// Dynamics constraints - RK4 integration
114+
for (int k = 0; k < N; ++k) {
115+
problem.subjectTo(
116+
eq(X.col(k + 1), rk4(CartPoleBenchmark::cartPoleDynamics, X.col(k), U.col(k), dt)));
117+
}
118+
119+
// Minimize sum squared inputs
120+
var J = new Variable(0.0);
121+
for (int k = 0; k < N; ++k) {
122+
J = J.plus(U.col(k).T().times(U.col(k)).get(0));
123+
}
124+
problem.minimize(J);
125+
126+
problem.solve(new Options().withDiagnostics(true));
127+
}
128+
}

benchmark/src/main/java/wpilib/robot/Main.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ public static void main(String... args) throws RunnerException {
3737
new Runner(opt).run();
3838
}
3939

40+
@Benchmark
41+
@BenchmarkMode(Mode.AverageTime)
42+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
43+
public void cartPole() {
44+
CartPoleBenchmark.cartPole();
45+
}
46+
4047
@Benchmark
4148
@BenchmarkMode(Mode.AverageTime)
4249
@OutputTimeUnit(TimeUnit.MICROSECONDS)

wpimath/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ wpilib_cc_shared_library(
236236
wpilib_jni_java_library(
237237
name = "wpimath-java",
238238
srcs = [":generated_java"] + glob(["src/main/java/**/*.java"]),
239+
javacopts = [
240+
"-Xep:UnicodeInCode:OFF",
241+
],
239242
maven_artifact_name = "wpimath-java",
240243
maven_group_id = "org.wpilib.wpimath",
241244
native_libs = [":wpimathjni"],
@@ -288,6 +291,7 @@ wpilib_java_junit5_test(
288291
"//wpiunits:wpiunits-java",
289292
"//wpiutil:wpiutil-java",
290293
"@maven//:org_ejml_ejml_core",
294+
"@maven//:org_ejml_ejml_ddense",
291295
"@maven//:org_ejml_ejml_simple",
292296
"@maven//:us_hebi_quickbuf_quickbuf_runtime",
293297
],

wpimath/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@ include(DownloadAndCheck)
77

88
file(
99
GLOB wpimath_jni_src
10-
src/main/native/cpp/jni/ArmFeedforwardJNI.cpp
1110
src/main/native/cpp/jni/DAREJNI.cpp
1211
src/main/native/cpp/jni/EigenJNI.cpp
13-
src/main/native/cpp/jni/Ellipse2dJNI.cpp
1412
src/main/native/cpp/jni/Exceptions.cpp
1513
src/main/native/cpp/jni/StateSpaceUtilJNI.cpp
1614
src/main/native/cpp/jni/Transform3dJNI.cpp
1715
src/main/native/cpp/jni/Twist3dJNI.cpp
16+
src/main/native/cpp/jni/autodiff/GradientJNI.cpp
17+
src/main/native/cpp/jni/autodiff/HessianJNI.cpp
18+
src/main/native/cpp/jni/autodiff/JacobianJNI.cpp
19+
src/main/native/cpp/jni/autodiff/VariableJNI.cpp
20+
src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp
21+
src/main/native/cpp/jni/optimization/ProblemJNI.cpp
1822
)
1923

2024
# Java bindings
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) FIRST and other WPILib contributors.
2+
// Open Source Software; you can modify and/or share it under the terms of
3+
// the WPILib BSD license file in the root directory of this project.
4+
5+
package org.wpilib.math.autodiff;
6+
7+
/**
8+
* Expression type.
9+
*
10+
* <p>Used for autodiff caching.
11+
*/
12+
public enum ExpressionType {
13+
/** There is no expression. */
14+
NONE(0),
15+
/** The expression is a constant. */
16+
CONSTANT(1),
17+
/** The expression is composed of linear and lower-order operators. */
18+
LINEAR(2),
19+
/** The expression is composed of quadratic and lower-order operators. */
20+
QUADRATIC(3),
21+
/** The expression is composed of nonlinear and lower-order operators. */
22+
NONLINEAR(4);
23+
24+
/** ExpressionType value. */
25+
public final int value;
26+
27+
ExpressionType(int value) {
28+
this.value = value;
29+
}
30+
31+
/**
32+
* Converts integer to its corresponding enum value.
33+
*
34+
* @param x The integer.
35+
* @return The enum value.
36+
*/
37+
public static ExpressionType fromInt(int x) {
38+
return switch (x) {
39+
case 0 -> ExpressionType.NONE;
40+
case 1 -> ExpressionType.CONSTANT;
41+
case 2 -> ExpressionType.LINEAR;
42+
case 3 -> ExpressionType.QUADRATIC;
43+
case 4 -> ExpressionType.NONLINEAR;
44+
default -> null;
45+
};
46+
}
47+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) FIRST and other WPILib contributors.
2+
// Open Source Software; you can modify and/or share it under the terms of
3+
// the WPILib BSD license file in the root directory of this project.
4+
5+
package org.wpilib.math.autodiff;
6+
7+
import org.ejml.simple.SimpleMatrix;
8+
9+
/**
10+
* This class calculates the gradient of a variable with respect to a vector of variables.
11+
*
12+
* <p>The gradient is only recomputed if the variable expression is quadratic or higher order.
13+
*/
14+
public class Gradient implements AutoCloseable {
15+
private long m_handle;
16+
private int m_rows;
17+
18+
/**
19+
* Constructs a Gradient object.
20+
*
21+
* @param variable Variable of which to compute the gradient.
22+
* @param wrt Variable with respect to which to compute the gradient.
23+
*/
24+
public Gradient(Variable variable, Variable wrt) {
25+
this(variable, new VariableMatrix(wrt));
26+
}
27+
28+
/**
29+
* Constructs a Gradient object.
30+
*
31+
* @param variable Variable of which to compute the gradient.
32+
* @param wrt Vector of variables with respect to which to compute the gradient.
33+
*/
34+
public Gradient(Variable variable, VariableMatrix wrt) {
35+
assert wrt.cols() == 1;
36+
37+
m_handle = GradientJNI.create(variable.getHandle(), wrt.getHandles());
38+
m_rows = wrt.rows();
39+
}
40+
41+
/**
42+
* Constructs a Gradient object.
43+
*
44+
* @param variable Variable of which to compute the gradient.
45+
* @param wrt Vector of variables with respect to which to compute the gradient.
46+
*/
47+
public Gradient(Variable variable, VariableBlock wrt) {
48+
this(variable, new VariableMatrix(wrt));
49+
}
50+
51+
@Override
52+
public void close() {
53+
if (m_handle != 0) {
54+
GradientJNI.destroy(m_handle);
55+
m_handle = 0;
56+
}
57+
}
58+
59+
/**
60+
* Returns the gradient as a VariableMatrix.
61+
*
62+
* <p>This is useful when constructing optimization problems with derivatives in them.
63+
*
64+
* @return The gradient as a VariableMatrix.
65+
*/
66+
public VariableMatrix get() {
67+
return new VariableMatrix(m_rows, 1, GradientJNI.get(m_handle));
68+
}
69+
70+
/**
71+
* Evaluates the gradient at wrt's value.
72+
*
73+
* @return The gradient at wrt's value.
74+
*/
75+
public SimpleMatrix value() {
76+
return GradientJNI.value(m_handle).toSimpleMatrix(m_rows, 1);
77+
}
78+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) FIRST and other WPILib contributors.
2+
// Open Source Software; you can modify and/or share it under the terms of
3+
// the WPILib BSD license file in the root directory of this project.
4+
5+
package org.wpilib.math.autodiff;
6+
7+
import org.wpilib.math.jni.WPIMathJNI;
8+
9+
/** Gradient JNI functions. */
10+
final class GradientJNI extends WPIMathJNI {
11+
private GradientJNI() {
12+
// Utility class.
13+
}
14+
15+
/**
16+
* Constructs a Gradient object.
17+
*
18+
* @param variable Variable of which to compute the Gradient.
19+
* @param wrt Vector of variables with respect to which to compute the Gradient.
20+
*/
21+
static native long create(long variable, long[] wrt);
22+
23+
/**
24+
* Destructs a Gradient.
25+
*
26+
* @param handle Gradient handle.
27+
*/
28+
static native void destroy(long handle);
29+
30+
/**
31+
* Returns the Gradient as an array of Variable handles.
32+
*
33+
* <p>This is useful when constructing optimization problems with derivatives in them.
34+
*
35+
* @param handle Gradient handle.
36+
* @return The Gradient as an array of Variable handles.
37+
*/
38+
static native long[] get(long handle);
39+
40+
/**
41+
* Evaluates the Gradient at wrt's value.
42+
*
43+
* @param handle Gradient handle.
44+
* @return A record containing the triplet row, column and value arrays (int[], int[], and
45+
* double[] respectively).
46+
*/
47+
static native NativeSparseTriplets value(long handle);
48+
}

0 commit comments

Comments
 (0)