Head Stabilization
- class flygym.examples.head_stabilization.JointAngleScaler
A class for standardizing joint angles, i.e., subtracting the per-joint mean and dividing by the per-joint standard deviation.
- Attributes:
- mean : np.ndarray
The mean values used for scaling.
- std : np.ndarray
The standard deviation values used for scaling.
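Assuming the usual z-score definition of standardization, each joint angle $\theta_{t,j}$ (time step $t$, joint $j$) is shifted by the $j$-th entry of `mean` ($\mu_j$) and divided by the $j$-th entry of `std` ($\sigma_j$); the inverse recovers the raw angle:

```latex
\tilde{\theta}_{t,j} = \frac{\theta_{t,j} - \mu_j}{\sigma_j},
\qquad
\theta_{t,j} = \sigma_j \, \tilde{\theta}_{t,j} + \mu_j,
\qquad
t = 1, \dots, n_\text{samples},\; j = 1, \dots, n_\text{joints}
```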
- classmethod from_data(joint_angles: ndarray)
Create a JointAngleScaler instance from joint angle data. The mean and standard deviation values are calculated from the data.
- Parameters:
- joint_angles : np.ndarray
The joint angle data. The shape should be (n_samples, n_joints), where n_samples is, for example, the length of a time series of joint angles.
- Returns:
- JointAngleScaler
A JointAngleScaler instance.
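The following is a minimal NumPy sketch of what computing scaler statistics "from data" amounts to for an array of shape (n_samples, n_joints). The helper names `scaler_params_from_data` and `standardize` are hypothetical and not part of flygym, and the use of NumPy's default population standard deviation (`ddof=0`) is an assumption about the real class.

```python
import numpy as np


def scaler_params_from_data(joint_angles: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: per-joint mean and std over the sample (time) axis."""
    if joint_angles.ndim != 2:
        raise ValueError("expected shape (n_samples, n_joints)")
    # Reduce over axis 0 so that each joint gets its own statistics.
    mean = joint_angles.mean(axis=0)
    std = joint_angles.std(axis=0)  # assumption: ddof=0 (NumPy default)
    return mean, std


def standardize(joint_angles: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    # Broadcasting: (n_samples, n_joints) against (n_joints,).
    return (joint_angles - mean) / std


# Synthetic joint angles for illustration: 1000 time steps, 42 DoFs.
rng = np.random.default_rng(0)
angles = rng.normal(loc=0.3, scale=2.0, size=(1000, 42))
mean, std = scaler_params_from_data(angles)
scaled = standardize(angles, mean, std)
print(mean.shape, std.shape)  # (42,) (42,)
```

After standardization every joint has zero mean and unit standard deviation over the data it was fitted on. A joint that never moves would have a standard deviation of zero, so a real pipeline would need to guard against division by zero.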
- classmethod from_params(mean: ndarray, std: ndarray)
Create a JointAngleScaler instance from predetermined mean and standard deviation values.
- Parameters:
- mean : np.ndarray
The mean values. The shape should be (n_joints,).
- std : np.ndarray
The standard deviation values. The shape should be (n_joints,).
- Returns:
- JointAngleScaler
A JointAngleScaler instance.
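Predetermined parameters are typically stored alongside a trained model (compare the `scaler_param_path` pickle used by `HeadStabilizationInferenceWrapper`). The layout of that pickle is not documented here, so this sketch assumes a dict with hypothetical `"mean"` and `"std"` keys and a hypothetical file name; it only shows a save/load round trip and applying the loaded values.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical parameters for 42 joints; in practice they would be fitted on training data.
mean = np.linspace(-1.0, 1.0, 42)
std = np.full(42, 0.5)

with tempfile.TemporaryDirectory() as tmp:
    param_path = Path(tmp) / "joint_angle_scaler_params.pkl"  # hypothetical file name
    # Assumed layout: a dict of two arrays, each of shape (n_joints,).
    with open(param_path, "wb") as f:
        pickle.dump({"mean": mean, "std": std}, f)
    with open(param_path, "rb") as f:
        params = pickle.load(f)

loaded_mean, loaded_std = params["mean"], params["std"]

# Reusing fixed parameters keeps the scaling identical between training and deployment.
sample = np.zeros((1, 42))
scaled = (sample - loaded_mean) / loaded_std
```

With flygym installed, the loaded arrays would be passed as `JointAngleScaler.from_params(mean=loaded_mean, std=loaded_std)`.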
- class flygym.examples.head_stabilization.WalkingDataset(sim_data_file: Path, contact_force_thr: tuple[float, float, float] = (0.5, 1, 3), joint_angle_scaler: Callable | None = None, ignore_first_n: int = 200, joint_mask=None)
PyTorch Dataset class for walking data.
- Parameters:
- sim_data_file : Path
The path to the simulation data file.
- contact_force_thr : tuple[float, float, float], optional
The threshold values for contact forces, by default (0.5, 1, 3).
- joint_angle_scaler : Optional[Callable], optional
A callable object used to scale joint angles, by default None.
- ignore_first_n : int, optional
The number of initial data points to ignore, by default 200.
- joint_mask : optional
A mask to apply to joint angles, by default None.
- Attributes:
- gait : str
The type of gait.
- terrain : str
The type of terrain.
- subset : str
The subset of the data, i.e., “train” or “test”.
- dn_drive : str
The DN drive used to generate the data.
- contact_force_thr : np.ndarray
The threshold values for contact forces.
- joint_angle_scaler : Callable
The callable object used to scale joint angles.
- ignore_first_n : int
The number of initial data points to ignore.
- joint_mask : Optional
The mask applied to joint angles. It is used to zero out certain DoFs in order to evaluate which DoFs are likely more important for head stabilization.
- contains_fly_flip : bool
Whether the simulation data contain fly flip errors.
- contains_physics_error : bool
Whether the simulation data contain physics errors.
- roll_pitch_ts : np.ndarray
The optimal roll and pitch correction angles. The shape is (n_samples, 2).
- joint_angles : np.ndarray
The scaled joint angle time series. The shape is (n_samples, n_joints).
- contact_mask : np.ndarray
The contact force mask (i.e., 1 if the leg is touching the floor, 0 otherwise). The shape is (n_samples, 6).
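`contact_force_thr` has three values while `contact_mask` has six columns. One plausible reading, assumed here and not confirmed by the reference above, is one threshold per leg pair (front, middle, hind) repeated for the left and right legs in the order LF, LM, LH, RF, RM, RH. The sketch below uses synthetic data, assumes contact forces already reduced to one magnitude per leg, and also shows `ignore_first_n` trimming and a hypothetical boolean `joint_mask` (True keeps a DoF, False zeroes it out).

```python
import numpy as np

rng = np.random.default_rng(0)
n_steps, n_joints = 1000, 42
ignore_first_n = 200
contact_force_thr = (0.5, 1, 3)

# Synthetic stand-ins for simulation outputs.
joint_angles = rng.normal(size=(n_steps, n_joints))
contact_forces = rng.uniform(0, 5, size=(n_steps, 6))  # assumed: one magnitude per leg

# Assumed leg order LF, LM, LH, RF, RM, RH: repeat the per-pair thresholds.
thr_per_leg = np.array([*contact_force_thr, *contact_force_thr])  # shape (6,)

# Drop the initial transient.
joint_angles = joint_angles[ignore_first_n:]
contact_forces = contact_forces[ignore_first_n:]

# 1 where the force exceeds the leg's threshold, 0 otherwise.
contact_mask = (contact_forces > thr_per_leg).astype(np.float32)

# Hypothetical joint mask: zero out the first 7 DoFs, e.g. for an ablation.
joint_mask = np.ones(n_joints, dtype=bool)
joint_mask[:7] = False
joint_angles = joint_angles * joint_mask
```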
- class flygym.examples.head_stabilization.ThreeLayerMLP
A PyTorch Lightning module for a three-layer MLP that predicts the head roll and pitch correction angles based on proprioceptive and tactile information.
- configure_optimizers()
Configure the optimizer (Adam).
- forward(x)
Forward pass through the model.
- Parameters:
- x : torch.Tensor
The input tensor. The shape should be (n_samples, 42 + 6), where 42 is the number of joint angles and 6 is the number of contact mask values (one per leg).
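The sketch below assembles an input of shape (n_samples, 42 + 6) from scaled joint angles and a contact mask using synthetic values. The feature order (joint angles first, then contact flags) is an assumption read off the "42 + 6" notation and is not confirmed by the reference.

```python
import numpy as np

n_samples = 5
# Synthetic, already-scaled joint angles and a fixed contact pattern.
joint_angles = np.random.default_rng(1).normal(size=(n_samples, 42))
contact_mask = np.array([[1, 1, 0, 1, 0, 1]] * n_samples, dtype=np.float32)

# Assumed feature order: 42 joint angles, then 6 contact flags.
x = np.concatenate([joint_angles, contact_mask], axis=1).astype(np.float32)
print(x.shape)  # (5, 48)
```

`torch.from_numpy(x)` would then turn this array into the tensor expected by `forward`.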
- training_step(batch, batch_idx)
Training step of the PyTorch Lightning module.
- validation_step(batch, batch_idx)
Validation step of the PyTorch Lightning module.
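To make the shape contract concrete, here is a NumPy-only illustration of a forward pass through three fully connected layers mapping 48 inputs to 2 outputs (roll, pitch). The hidden widths (32, 8), the ReLU activation, and the weight initialization are all hypothetical choices, and this is not the actual Lightning module, which is implemented in PyTorch.

```python
import numpy as np


def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0.0)


def init_mlp(sizes, seed=0):
    """Random (weight, bias) pairs for consecutive dense layers (illustrative only)."""
    rng = np.random.default_rng(seed)
    return [
        (rng.normal(scale=1 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def mlp_forward(params, x):
    """Apply each dense layer; ReLU between layers, linear output for (roll, pitch)."""
    h = x
    for i, (w, b) in enumerate(params):
        h = h @ w + b
        if i < len(params) - 1:
            h = relu(h)
    return h


# Hypothetical widths: 48 inputs -> 32 -> 8 -> 2 outputs (three linear layers).
params = init_mlp([48, 32, 8, 2])
x = np.zeros((10, 48))
y = mlp_forward(params, x)
print(y.shape)  # (10, 2)
```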
- class flygym.examples.head_stabilization.HeadStabilizationInferenceWrapper(model_path: Path, scaler_param_path: Path, contact_force_thr: tuple[float, float, float] = (0.5, 1, 3))
Wrapper for the head stabilization model that makes predictions from observations. Whereas training operates on large batched tensors, this class provides a “flat” interface for predicting from one observation (i.e., one time step) at a time, which is what closed-loop deployment of the model requires.
- Parameters:
- model_path : Path
The path to the trained model.
- scaler_param_path : Path
The path to the pickle file containing the joint angle scaler parameters.
- contact_force_thr : tuple[float, float, float], optional
The threshold values for contact forces that are used to determine the floor contact flags, by default (0.5, 1, 3).
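The "flat" interface boils down to scaling one observation, thresholding its contact forces, adding a batch dimension, running the batched model, and removing the batch dimension again. The class `PerStepPredictor` below is a hypothetical sketch of that pattern; its constructor and call signature are assumptions, not the real wrapper's API, and the per-leg threshold layout is the same assumption as above.

```python
from typing import Callable

import numpy as np


class PerStepPredictor:
    """Hypothetical sketch of a one-observation-at-a-time prediction interface."""

    def __init__(
        self,
        batch_model: Callable[[np.ndarray], np.ndarray],
        mean: np.ndarray,
        std: np.ndarray,
        contact_force_thr: tuple[float, float, float] = (0.5, 1, 3),
    ):
        self.batch_model = batch_model  # maps (n, 48) -> (n, 2)
        self.mean = mean
        self.std = std
        # Assumed: one threshold per leg pair, repeated for left and right legs.
        self.thr = np.array([*contact_force_thr, *contact_force_thr])

    def __call__(self, joint_angles: np.ndarray, contact_forces: np.ndarray) -> np.ndarray:
        scaled = (joint_angles - self.mean) / self.std       # shape (42,)
        contact = (contact_forces > self.thr).astype(float)  # shape (6,)
        x = np.concatenate([scaled, contact])[None, :]       # add batch dim -> (1, 48)
        return self.batch_model(x)[0]                        # drop batch dim -> (2,)


# Toy stand-in for a trained model that always predicts zero roll and pitch.
predictor = PerStepPredictor(lambda x: np.zeros((x.shape[0], 2)), np.zeros(42), np.ones(42))
roll, pitch = predictor(np.zeros(42), np.full(6, 2.0))
```

In flygym itself, the model checkpoint and scaler parameter paths for the real wrapper can be obtained with `get_head_stabilization_model_paths()`.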
- flygym.examples.head_stabilization.util.get_head_stabilization_model_paths() → tuple[Path, Path]
Get the paths to the head stabilization models.
- Returns:
- Path
Path to the head stabilization model checkpoint.
- Path
Path to the pickle file containing joint angle scaler parameters.