# Kalman filter implemented as a PyTorch nn.Module.
import torch
import torch.nn as nn
class KalmanFilter(nn.Module):
    """Linear Kalman filter with its model matrices held in an ``nn.Module``.

    The model matrices ``F`` (state transition), ``H`` (measurement),
    ``Q`` (process noise covariance) and ``R`` (measurement noise covariance)
    are frozen ``nn.Parameter`` objects (``requires_grad=False``). They
    therefore appear in ``state_dict()`` and follow ``.to()`` / ``.double()``.

    The running estimate ``x`` (a ``(state_dim, 1)`` column vector) and its
    covariance ``P`` (``(state_dim, state_dim)``) are registered as
    non-persistent buffers. They move with the module on device/dtype
    conversions but are not written to ``state_dict()``.
    """

    def __init__(self, state_dim, measurement_dim):
        """
        Initialize Kalman Filter

        Args:
            state_dim (int): Dimension of state vector
            measurement_dim (int): Dimension of measurement vector
        """
        super(KalmanFilter, self).__init__()
        self.state_dim = state_dim
        self.measurement_dim = measurement_dim
        # State transition matrix (F): identity by default.
        self.F = nn.Parameter(torch.eye(state_dim), requires_grad=False)
        # Measurement matrix (H): directly observe the leading state components.
        # Clamp to the smaller dimension so measurement_dim > state_dim does
        # not raise a shape mismatch.
        self.H = nn.Parameter(torch.zeros(measurement_dim, state_dim), requires_grad=False)
        k = min(measurement_dim, state_dim)
        self.H.data[:k, :k] = torch.eye(k)
        # Process noise covariance (Q)
        self.Q = nn.Parameter(torch.eye(state_dim) * 0.1, requires_grad=False)
        # Measurement noise covariance (R)
        self.R = nn.Parameter(torch.eye(measurement_dim) * 1.0, requires_grad=False)
        # State estimate and covariance. As buffers they follow .to()/.cuda()/
        # .double() together with F/H/Q/R. persistent=False keeps state_dict()
        # keys unchanged ("F", "H", "Q", "R").
        self.register_buffer("x", torch.zeros(state_dim, 1), persistent=False)
        self.register_buffer("P", torch.eye(state_dim), persistent=False)

    def predict(self):
        """
        Prediction step of Kalman Filter: x <- F x, P <- F P F^T + Q.

        Returns:
            tuple: (predicted_state, predicted_covariance)
        """
        F = self.F
        self.x = F @ self.x
        self.P = F @ self.P @ F.t() + self.Q
        return self.x, self.P

    def update(self, measurement):
        """
        Update step of Kalman Filter.

        Args:
            measurement (torch.Tensor | sequence): Measurement vector, either
                1-D of length ``measurement_dim`` or a column of shape
                ``(measurement_dim, 1)``. It is cast to the dtype/device of
                the state estimate.

        Returns:
            tuple: (updated_state, updated_covariance)
        """
        # Cast to the estimate's dtype/device so that e.g. a float64 or list
        # measurement does not make the K @ innovation matmul fail.
        z = torch.as_tensor(measurement, dtype=self.x.dtype, device=self.x.device)
        if z.dim() == 1:
            z = z.unsqueeze(1)
        H, P, R = self.H, self.P, self.R
        PHt = P @ H.t()
        # Innovation covariance.
        S = H @ PHt + R
        # Kalman gain K = P H^T S^-1. It is computed as a linear solve
        # (K S = P H^T  <=>  S^T K^T = H P^T) rather than via an explicit
        # inverse, which is better conditioned.
        K = torch.linalg.solve(S.t(), PHt.t()).t()
        innovation = z - H @ self.x
        self.x = self.x + K @ innovation
        eye = torch.eye(self.state_dim, dtype=P.dtype, device=P.device)
        i_kh = eye - K @ H
        # Joseph form: equal to (I - K H) P for the optimal gain, but it keeps
        # P symmetric positive semi-definite under round-off.
        self.P = i_kh @ P @ i_kh.t() + K @ R @ K.t()
        return self.x, self.P

    @staticmethod
    def _like(value, ref):
        """Return *value* as a tensor with the dtype and device of *ref*."""
        return torch.as_tensor(value, dtype=ref.dtype, device=ref.device)

    def set_state_transition(self, F):
        """Set state transition matrix (cast to the current dtype/device of F)."""
        self.F.data = self._like(F, self.F)

    def set_measurement_matrix(self, H):
        """Set measurement matrix (cast to the current dtype/device of H)."""
        self.H.data = self._like(H, self.H)

    def set_process_noise(self, Q):
        """Set process noise covariance (cast to the current dtype/device of Q)."""
        self.Q.data = self._like(Q, self.Q)

    def set_measurement_noise(self, R):
        """Set measurement noise covariance (cast to the current dtype/device of R)."""
        self.R.data = self._like(R, self.R)

    def reset(self):
        """Reset the estimate to zeros and the covariance to identity.

        The new tensors use the filter's current dtype/device, taken from F.
        """
        ref = self.F
        self.x = torch.zeros(self.state_dim, 1, dtype=ref.dtype, device=ref.device)
        self.P = torch.eye(self.state_dim, dtype=ref.dtype, device=ref.device)
# Example usage
def demo_kalman_filter(num_steps=10, dt=0.1, noise_std=0.1, *, verbose=True):
    """Track a 2-D constant-velocity target from noisy position readings.

    The target starts at the origin with velocity (1, 1). Each step advances
    it with the constant-velocity model and observes its position with
    Gaussian noise. The filter then runs one predict/update cycle.

    Args:
        num_steps (int): Number of simulated time steps.
        dt (float): Time step of the constant-velocity model.
        noise_std (float): Standard deviation of the simulated position
            noise. The filter's measurement covariance R is set to
            ``noise_std**2 * I`` so the model matches the simulated data.
        verbose (bool): Print true, measured and estimated position per step.

    Returns:
        tuple: ``(true_states, measurements, estimates)``. These are lists of
        column tensors of shape ``(4, 1)``, ``(2, 1)`` and ``(4, 1)``.
    """
    kf = KalmanFilter(state_dim=4, measurement_dim=2)
    # Constant velocity model. State vector: [x, y, vx, vy]
    F = torch.tensor([
        [1, 0, dt, 0],
        [0, 1, 0, dt],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ], dtype=torch.float32)
    kf.set_state_transition(F)
    # Match R to the variance of the noise actually injected below; the
    # default R (1.0) would badly under-trust measurements with std 0.1.
    kf.set_measurement_noise(torch.eye(2) * noise_std ** 2)

    true_x = torch.tensor([0.0, 0.0, 1.0, 1.0]).unsqueeze(1)  # True state
    true_states, measurements, estimates = [], [], []
    for _ in range(num_steps):
        true_x = F @ true_x
        true_states.append(true_x.clone())
        # Noisy measurement of position only (H picks x and y), not velocity.
        measurement = kf.H @ true_x + torch.randn(2, 1) * noise_std
        measurements.append(measurement)
        kf.predict()
        state_estimate, _ = kf.update(measurement)
        estimates.append(state_estimate.clone())
        if verbose:
            print(f"True position: ({true_x[0,0]:.2f}, {true_x[1,0]:.2f})")
            print(f"Measured position: ({measurement[0,0]:.2f}, {measurement[1,0]:.2f})")
            print(f"Estimated position: ({state_estimate[0,0]:.2f}, {state_estimate[1,0]:.2f})\n")
    return true_states, measurements, estimates
# Public. Last updated: 2024-10-27 01:09:34 AM