diff --git a/src/shared/algorithms/Kalman/Kalman.h b/src/shared/algorithms/Kalman/Kalman.h
index 8673baf8eaa5de97e29bc29511fb9430408d7292..5d663be1745ee1a0d409d22bdfd797e084f89aec 100644
--- a/src/shared/algorithms/Kalman/Kalman.h
+++ b/src/shared/algorithms/Kalman/Kalman.h
@@ -33,7 +33,7 @@ namespace Boardcore
* This class uses templates in order to know the size of each matrix at
* compile-time. This way we avoid Eigen to allocate memory dynamically.
*/
-template <typename T, int N_size, int P_size>
+template <typename T, int N_size, int P_size, int M_size = 1>
class Kalman
{
public:
@@ -41,8 +41,10 @@ public:
using MatrixPN = Eigen::Matrix<T, P_size, N_size>;
using MatrixNP = Eigen::Matrix<T, N_size, P_size>;
using MatrixPP = Eigen::Matrix<T, P_size, P_size>;
+ using MatrixNM = Eigen::Matrix<T, N_size, M_size>;
using CVectorN = Eigen::Vector<T, N_size>;
using CVectorP = Eigen::Vector<T, P_size>;
+ using CVectorM = Eigen::Vector<T, M_size>;
/**
* @brief Configuration struct for the Kalman class.
@@ -54,6 +56,7 @@ public:
MatrixNN Q;
MatrixPP R;
MatrixNN P;
+ MatrixNM G;
CVectorN x;
};
@@ -64,8 +67,8 @@ public:
*/
Kalman(const KalmanConfig& config)
: F(config.F), H(config.H), Q(config.Q), R(config.R), P(config.P),
- S(MatrixPP::Zero(P_size, P_size)), K(MatrixNP::Zero(N_size, P_size)),
- x(config.x)
+ G(config.G), S(MatrixPP::Zero(P_size, P_size)),
+ K(MatrixNP::Zero(N_size, P_size)), x(config.x)
{
I.setIdentity();
}
@@ -77,6 +80,7 @@ public:
Q = config.Q;
R = config.R;
P = config.P;
+ G = config.G;
S = MatrixPP::Zero(P_size, P_size);
K = MatrixNP::Zero(N_size, P_size);
x = config.x;
@@ -96,12 +100,35 @@ public:
*
* @param F_new updated F matrix.
*/
- void predict(const MatrixNN& F_new)
+ void predictUpdateF(const MatrixNN& F_new)
{
F = F_new;
predict();
}
+ /**
+ * @brief Prediction step with previous F matrix and with the control
+ * vector.
+ */
+ void predictWithControl(const CVectorM& control)
+ {
+ x = F * x + G * control;
+ P = F * P * F.transpose() + Q;
+ }
+
+ /**
+ * @brief Prediction step.
+ *
+ * @param F_new updated F matrix.
+ * @param control Control vector.
+ */
+ void predictWithControlUpdateF(const MatrixNN& F_new,
+ const CVectorM& control)
+ {
+ F = F_new;
+ predictWithControl(control);
+ }
+
/**
* @brief Correction step.
*
@@ -169,6 +196,7 @@ private:
MatrixNN Q; /**< Model variance matrix (n x n) */
MatrixPP R; /**< Measurement variance (p x p) */
MatrixNN P; /**< Error covariance matrix (n x n) */
+ MatrixNM G; /**< Input matrix (n x m) */
MatrixPP S;
MatrixNP K; /**< kalman gain */
diff --git a/src/tests/algorithms/Kalman/test-kalman-benchmark.cpp b/src/tests/algorithms/Kalman/test-kalman-benchmark.cpp
index 52ca7e81c5ea91644604209d62d87c177f02ca6f..4d22346370c1703ba62009d9f9121f75eedb7328 100644
--- a/src/tests/algorithms/Kalman/test-kalman-benchmark.cpp
+++ b/src/tests/algorithms/Kalman/test-kalman-benchmark.cpp
@@ -63,6 +63,8 @@ int main()
// Measurement variance
Matrix<float, p, p> R{10};
+ Matrix<float, n, 1> G = Matrix<float, n, 1>::Zero();
+
Matrix<float, n, 1> x0(INPUT[0], 0.0, 0.0);
Matrix<float, p, 1> y(p); // vector with p elements (only one in this case)
@@ -73,6 +75,7 @@ int main()
config.Q = Q;
config.R = R;
config.P = P;
+ config.G = G;
config.x = x0;
Kalman<float, n, p> filter(config);
@@ -98,7 +101,7 @@ int main()
y(0) = INPUT[i];
- filter.predict(F);
+ filter.predictUpdateF(F);
if (!filter.correct(y))
printf("Correction failed at iteration : %u \n", i);
diff --git a/src/tests/catch/test-kalman.cpp b/src/tests/catch/test-kalman.cpp
index 71974bebedf31ff6790e7cc7104c9c76b5fd2538..e44e3fb2eee1f51f3f4534226093579348a651f2 100644
--- a/src/tests/catch/test-kalman.cpp
+++ b/src/tests/catch/test-kalman.cpp
@@ -58,6 +58,8 @@ static const Matrix<float, STATES_DIM, STATES_DIM> Q =
.finished();
// Measurement variance
static const Matrix<float, OUTPUTS_DIM, OUTPUTS_DIM> R{10};
+static const Kalman<float, STATES_DIM, OUTPUTS_DIM>::MatrixNM G =
+ Kalman<float, STATES_DIM, OUTPUTS_DIM>::MatrixNM::Zero();
// State vector
static const Matrix<float, STATES_DIM, 1> x0(INPUT[0], 0.0, 0.0);
@@ -70,6 +72,7 @@ getKalmanConfig()
config.Q = Q;
config.R = R;
config.P = P;
+ config.G = G;
config.x = x0;
return config;
@@ -95,7 +98,7 @@ TEST_CASE("Update test")
F_new(0, 2) = 0.5 * T * T;
F_new(1, 2) = T;
- filter.predict(F_new);
+ filter.predictUpdateF(F_new);
if (!filter.correct(y))
{