Kalman filter code

from typing import Tuple

import torch
import torch.nn as nn


def _as_tensor_like(value, like: torch.Tensor) -> torch.Tensor:
    """Convert *value* to a tensor with the same dtype and device as *like*."""
    return torch.as_tensor(value, dtype=like.dtype, device=like.device)


class KalmanFilter(nn.Module):
    """Linear Kalman filter with fixed (non-trainable) model matrices.

    The filter implements the standard linear model::

        x_k = F x_{k-1} + w,   Cov(w) = Q
        z_k = H x_k + v,       Cov(v) = R

    ``F``, ``H``, ``Q`` and ``R`` are ``nn.Parameter`` objects with
    ``requires_grad=False``, so they are saved in ``state_dict()`` and follow
    ``.to()`` / ``.double()``. The running estimate ``x`` (a
    ``(state_dim, 1)`` column) and its covariance ``P`` are non-persistent
    buffers. They also follow device/dtype moves but are not saved in
    ``state_dict()``.
    """

    def __init__(self, state_dim: int, measurement_dim: int):
        """Initialize the filter.

        Args:
            state_dim: Dimension of the state vector.
            measurement_dim: Dimension of the measurement vector.
        """
        super().__init__()
        self.state_dim = state_dim
        self.measurement_dim = measurement_dim

        # State transition matrix F: identity until set_state_transition().
        self.F = nn.Parameter(torch.eye(state_dim), requires_grad=False)
        # Measurement matrix H: observes the first min(measurement_dim,
        # state_dim) state components directly. torch.eye(m, n) builds this
        # rectangular identity for any m and n.
        self.H = nn.Parameter(
            torch.eye(measurement_dim, state_dim), requires_grad=False
        )
        # Process noise covariance Q.
        self.Q = nn.Parameter(torch.eye(state_dim) * 0.1, requires_grad=False)
        # Measurement noise covariance R.
        self.R = nn.Parameter(
            torch.eye(measurement_dim) * 1.0, requires_grad=False
        )

        # Running estimate and covariance. These are non-persistent buffers,
        # so .to(device) / .double() move them together with F/H/Q/R while
        # state_dict() keeps only F, H, Q and R.
        self.register_buffer("x", torch.zeros(state_dim, 1), persistent=False)
        self.register_buffer("P", torch.eye(state_dim), persistent=False)

    def predict(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the prediction (time-update) step.

        Computes ``x <- F x`` and ``P <- F P F^T + Q``.

        Returns:
            tuple: ``(predicted_state, predicted_covariance)``. These are the
            same tensors now stored in ``self.x`` and ``self.P``.
        """
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return self.x, self.P

    def update(self, measurement) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the measurement-update (correction) step.

        Args:
            measurement: Measurement vector, either 1-D of length
                ``measurement_dim`` or a ``(measurement_dim, 1)`` column.
                Anything accepted by ``torch.as_tensor`` works. It is cast
                to the dtype and device of the filter state.

        Returns:
            tuple: ``(updated_state, updated_covariance)``. These are the same
            tensors now stored in ``self.x`` and ``self.P``.
        """
        z = _as_tensor_like(measurement, self.x)
        if z.dim() == 1:
            z = z.unsqueeze(1)

        H, P = self.H, self.P
        # Innovation covariance S = H P H^T + R.
        S = H @ P @ H.T + self.R
        # Kalman gain K = P H^T S^{-1}. Solve K S = P H^T, i.e.
        # S^T K^T = (P H^T)^T, instead of forming S^{-1} explicitly. This is
        # better conditioned and cheaper than torch.inverse.
        PHt = P @ H.T
        K = torch.linalg.solve(S.T, PHt.T).T

        innovation = z - H @ self.x
        self.x = self.x + K @ innovation

        # Joseph-form covariance update. For the optimal gain it is
        # algebraically equal to (I - K H) P, but it keeps P symmetric
        # positive semi-definite under floating-point round-off.
        eye = torch.eye(self.state_dim, dtype=P.dtype, device=P.device)
        A = eye - K @ H
        self.P = A @ P @ A.T + K @ self.R @ K.T
        return self.x, self.P

    def set_state_transition(self, F) -> None:
        """Set the state transition matrix F (``state_dim x state_dim``).

        The value is cast to the current dtype and device of ``self.F``.
        """
        self.F.data = _as_tensor_like(F, self.F)

    def set_measurement_matrix(self, H) -> None:
        """Set the measurement matrix H (``measurement_dim x state_dim``).

        The value is cast to the current dtype and device of ``self.H``.
        """
        self.H.data = _as_tensor_like(H, self.H)

    def set_process_noise(self, Q) -> None:
        """Set the process noise covariance Q (``state_dim x state_dim``).

        The value is cast to the current dtype and device of ``self.Q``.
        """
        self.Q.data = _as_tensor_like(Q, self.Q)

    def set_measurement_noise(self, R) -> None:
        """Set the measurement noise covariance R.

        R has shape ``measurement_dim x measurement_dim``. The value is cast
        to the current dtype and device of ``self.R``.
        """
        self.R.data = _as_tensor_like(R, self.R)

    def reset(self) -> None:
        """Reset the estimate to zeros and the covariance to identity.

        The filter's current dtype and device are kept.
        """
        self.x = self.x.new_zeros(self.state_dim, 1)
        self.P = torch.eye(self.state_dim, dtype=self.P.dtype, device=self.P.device)


# Example usage
def demo_kalman_filter(steps: int = 10):
    """Track a 2-D constant-velocity target from noisy position measurements.

    For each step, prints the true, measured and estimated positions.

    Args:
        steps: Number of simulation steps (default 10).

    Returns:
        tuple: ``(true_states, measurements, estimates)``. Each is a list with
        one column tensor per step.
    """
    kf = KalmanFilter(state_dim=4, measurement_dim=2)

    # Constant-velocity model; state vector is [x, y, vx, vy].
    dt = 0.1
    F = torch.tensor(
        [
            [1, 0, dt, 0],
            [0, 1, 0, dt],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ],
        dtype=torch.float32,
    )
    kf.set_state_transition(F)

    # The simulated measurement noise has std 0.1 (variance 0.01). Give the
    # filter the matching R instead of the default identity.
    noise_std = 0.1
    kf.set_measurement_noise(torch.eye(2) * noise_std**2)

    true_x = torch.tensor([0.0, 0.0, 1.0, 1.0]).unsqueeze(1)  # True state
    true_states, measurements, estimates = [], [], []
    for _ in range(steps):
        # Advance the true state.
        true_x = F @ true_x
        # Noisy measurement of position only (H picks the first two states).
        measurement = kf.H @ true_x + torch.randn(2, 1) * noise_std

        # Kalman filter steps.
        kf.predict()
        state_estimate, _ = kf.update(measurement)

        true_states.append(true_x.clone())
        measurements.append(measurement)
        estimates.append(state_estimate.clone())

        print(f"True position: ({true_x[0,0]:.2f}, {true_x[1,0]:.2f})")
        print(f"Measured position: ({measurement[0,0]:.2f}, {measurement[1,0]:.2f})")
        print(f"Estimated position: ({state_estimate[0,0]:.2f}, {state_estimate[1,0]:.2f})\n")

    return true_states, measurements, estimates


if __name__ == "__main__":
    demo_kalman_filter()

Public · Last updated: 2024-10-27 01:09:34 AM