Head Stabilization

Model

class flygym.examples.head_stabilization.JointAngleScaler

Bases: object

A class for standardizing joint angles (i.e., scaling them using their mean and standard deviation).
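For reference, the standardization of a single value can be written as follows. Here x_{t,j} is the angle of joint j at time step t, and N is n_samples. mu_j and sigma_j correspond to the per-joint entries of mean and std. The population standard deviation (ddof = 0, NumPy's default) is an assumption; this section does not say which estimator is used.

```latex
z_{t,j} = \frac{x_{t,j} - \mu_j}{\sigma_j}, \qquad
\mu_j = \frac{1}{N} \sum_{t=1}^{N} x_{t,j}, \qquad
\sigma_j = \sqrt{\frac{1}{N} \sum_{t=1}^{N} \left( x_{t,j} - \mu_j \right)^2}
```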

Attributes:
mean : np.ndarray

The mean values used for scaling.

std : np.ndarray

The standard deviation values used for scaling.

classmethod from_data(joint_angles: ndarray)

Create a JointAngleScaler instance from joint angle data. The mean and standard deviation are computed for each joint across all samples.

Parameters:
joint_angles : np.ndarray

The joint angle data. The shape should be (n_samples, n_joints) where n_samples is, for example, the length of a time series of joint angles.

Returns:
JointAngleScaler

A JointAngleScaler instance.
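A minimal sketch of the statistics that from_data is described as computing. It reduces over the sample axis of an (n_samples, n_joints) array to get two (n_joints,) vectors. The helper compute_scaler_params is hypothetical and is not part of flygym. The use of the population standard deviation is an assumption.

```python
import numpy as np


def compute_scaler_params(joint_angles: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper (not the flygym implementation): per-joint mean and
    standard deviation over the sample axis of an (n_samples, n_joints) array."""
    if joint_angles.ndim != 2:
        raise ValueError("expected shape (n_samples, n_joints)")
    mean = joint_angles.mean(axis=0)  # shape (n_joints,)
    std = joint_angles.std(axis=0)    # shape (n_joints,); population std (ddof=0) assumed
    return mean, std


# Toy data: 4 time steps, 3 joints (the middle joint never moves)
joint_angles = np.array([
    [0.0, 1.0, 2.0],
    [2.0, 1.0, 4.0],
    [4.0, 1.0, 6.0],
    [6.0, 1.0, 8.0],
])
mean, std = compute_scaler_params(joint_angles)
```

The constant middle joint gets a standard deviation of 0, so dividing by it would yield inf or nan. This section does not document how the library handles that case.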

classmethod from_params(mean: ndarray, std: ndarray)

Create a JointAngleScaler instance from predetermined mean and standard deviation values.

Parameters:
mean : np.ndarray

The mean values. The shape should be (n_joints,).

std : np.ndarray

The standard deviation values. The shape should be (n_joints,).

Returns:
JointAngleScaler

A JointAngleScaler instance.
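WalkingDataset accepts its scaler as a Callable, which suggests that scalers are applied by calling them. However, the call interface of JointAngleScaler is not documented here. The following sketch therefore uses a hypothetical stand-in, SimpleJointAngleScaler, which is built from predetermined parameters as from_params is described to do. It applies (x - mean) / std when called.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleJointAngleScaler:
    """Hypothetical stand-in for JointAngleScaler; not the flygym class."""

    mean: np.ndarray  # shape (n_joints,)
    std: np.ndarray   # shape (n_joints,)

    @classmethod
    def from_params(cls, mean, std) -> "SimpleJointAngleScaler":
        mean = np.asarray(mean, dtype=float)
        std = np.asarray(std, dtype=float)
        if mean.ndim != 1 or mean.shape != std.shape:
            raise ValueError("mean and std must both have shape (n_joints,)")
        return cls(mean=mean, std=std)

    def __call__(self, joint_angles: np.ndarray) -> np.ndarray:
        # Broadcasting subtracts/divides per joint, so this works for a
        # (n_samples, n_joints) array and for a single (n_joints,) observation.
        return (joint_angles - self.mean) / self.std


scaler = SimpleJointAngleScaler.from_params(mean=[3.0, 5.0], std=[2.0, 0.5])
scaled = scaler(np.array([[5.0, 5.5], [1.0, 4.0]]))
```

Because the same object handles both batched arrays and single observations, it could serve during training and in closed-loop deployment alike.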

class flygym.examples.head_stabilization.WalkingDataset(sim_data_file: Path, contact_force_thr: tuple[float, float, float] = (0.5, 1, 3), joint_angle_scaler: Callable | None = None, ignore_first_n: int = 200, joint_mask=None)

Bases: Dataset

PyTorch Dataset class for walking data.

Parameters:
sim_data_file : Path

The path to the simulation data file.

contact_force_thr : tuple[float, float, float], optional

The threshold values for contact forces, by default (0.5, 1, 3).

joint_angle_scaler : Optional[Callable], optional

A callable object used to scale joint angles, by default None.

ignore_first_n : int, optional

The number of initial data points to ignore, by default 200.

joint_mask : Optional, optional

A mask to apply to joint angles, by default None.
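The parameters above suggest a preprocessing pipeline for the joint angles: drop the first samples, scale, and zero out masked DoFs. The function preprocess_joint_angles below is a hypothetical sketch of such a pipeline, not the WalkingDataset code. Three details are assumptions: joint_mask is a boolean array of shape (n_joints,) in which False marks a DoF to zero out; the mask is applied after scaling so that masked inputs are exactly zero; and the processing order matches the dataset's.

```python
import numpy as np


def preprocess_joint_angles(joint_angles, scaler=None, ignore_first_n=200, joint_mask=None):
    """Hypothetical preprocessing mirroring the WalkingDataset parameters.

    Assumptions: joint_mask is a boolean (n_joints,) array where False marks a
    DoF to zero out, and masking happens after scaling.
    """
    x = np.asarray(joint_angles, dtype=float)[ignore_first_n:]  # drop initial samples
    if scaler is not None:
        x = scaler(x)  # e.g. standardize each joint
    if joint_mask is not None:
        x = x * np.asarray(joint_mask, dtype=bool)  # zero out masked DoFs
    return x


angles = np.arange(12, dtype=float).reshape(6, 2)  # 6 time steps, 2 joints
out = preprocess_joint_angles(
    angles, scaler=lambda a: a / 2, ignore_first_n=2, joint_mask=[True, False]
)
```

If the mask were applied before standardization, a masked joint would become -mean / std rather than 0. That is why this sketch masks last.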

Attributes:
gait : str

The type of gait.

terrain : str

The type of terrain.

subset : str

The subset of the data, i.e., “train” or “test”.

dn_drive : str

The DN drive used to generate the data.

contact_force_thr : np.ndarray

The threshold values for contact forces.

joint_angle_scaler : Callable

The callable object used to scale joint angles.

ignore_first_n : int

The number of initial data points to ignore.

joint_mask : Optional

The mask applied to joint angles. This is used to zero out certain DoFs to evaluate which DoFs are likely more important for head stabilization.

contains_fly_flip : bool

Indicates if the simulation data contains fly flip errors.

contains_physics_error : bool

Indicates if the simulation data contains physics errors.

roll_pitch_ts : np.ndarray

The optimal roll and pitch correction angles. The shape is (n_samples, 2).

joint_angles : np.ndarray

The scaled joint angle time series. The shape is (n_samples, n_joints).

contact_mask : np.ndarray

The contact mask derived from contact forces (i.e., 1 if the leg is touching the floor, 0 otherwise). The shape is (n_samples, 6), with one column per leg.
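A sketch of how a (n_samples, 6) contact mask could be obtained by thresholding per-leg contact force magnitudes. The helper contact_mask_from_forces is hypothetical. It rests on three assumptions not stated in this section: legs are ordered left front, middle, hind, then right front, middle, hind; the three values in contact_force_thr apply to front, middle, and hind legs on both sides; and contact means the force strictly exceeds the threshold.

```python
import numpy as np


def contact_mask_from_forces(force_magnitudes, contact_force_thr=(0.5, 1, 3)):
    """Hypothetical helper: threshold per-leg contact force magnitudes.

    force_magnitudes has shape (n_samples, 6), legs assumed ordered
    LF, LM, LH, RF, RM, RH; thresholds (front, middle, hind) are reused
    for the left and right sides.
    """
    thr = np.tile(np.asarray(contact_force_thr, dtype=float), 2)  # shape (6,)
    return (np.asarray(force_magnitudes) > thr).astype(np.int8)  # 1 = touching floor


forces = np.array([
    [0.6, 0.9, 3.5, 0.0, 2.0, 2.9],
    [0.4, 1.1, 2.0, 0.7, 0.5, 3.1],
])
mask = contact_mask_from_forces(forces)
```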

class flygym.examples.head_stabilization.ThreeLayerMLP

Bases: LightningModule

A PyTorch Lightning module for a three-layer MLP that predicts the head roll and pitch correction angles based on proprioception and tactile information.

configure_optimizers()

Configure the Adam optimizer for training.

forward(x)

Forward pass through the model.

Parameters:
x : torch.Tensor

The input tensor. The shape should be (n_samples, 42 + 6), where 42 is the number of joint angles and 6 is the number of contact flags (one per leg).
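A NumPy sketch of the shapes involved in the forward pass. Joint angles (42) and contact flags (6) are concatenated into a 48-dimensional input, and the output has two columns for roll and pitch. Several details are assumptions, not the flygym model: the hidden width (HIDDEN = 32), the ReLU activation, the reading of "three-layer" as three linear layers, and the joint-angles-then-contacts ordering. The helpers init_params and mlp_forward are hypothetical.

```python
import numpy as np

N_JOINTS, N_LEGS, N_OUT = 42, 6, 2  # from the docs: 42 joint angles, 6 contact flags, roll & pitch
HIDDEN = 32  # hypothetical hidden width; the real model's sizes are not documented here


def init_params(rng, sizes):
    """Random weights for a stack of linear layers (illustration only)."""
    return [
        (rng.normal(0.0, 0.1, (n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def mlp_forward(params, x):
    """Three linear layers with ReLU between them (activation is an assumption)."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:  # no activation after the output layer
            x = np.maximum(x, 0.0)
    return x


rng = np.random.default_rng(0)
params = init_params(rng, [N_JOINTS + N_LEGS, HIDDEN, HIDDEN, N_OUT])
joint_angles = rng.normal(size=(5, N_JOINTS))             # scaled joint angles
contact_mask = rng.integers(0, 2, size=(5, N_LEGS))       # 0/1 floor contact flags
x = np.concatenate([joint_angles, contact_mask], axis=1)  # shape (5, 48)
roll_pitch = mlp_forward(params, x)                       # shape (5, 2)
```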

training_step(batch, batch_idx)

Training step of the PyTorch Lightning module.

validation_step(batch, batch_idx)

Validation step of the PyTorch Lightning module.

class flygym.examples.head_stabilization.HeadStabilizationInferenceWrapper(model_path: Path, scaler_param_path: Path, contact_force_thr: tuple[float, float, float] = (0.5, 1, 3))

Bases: object

Wrapper for the head stabilization model to make predictions on observations. Whereas data are collected in large tensors during training, this class provides a “flat” interface for making predictions one observation (i.e., time step) at a time. This is useful for deploying the model in closed loop.

Parameters:
model_path : Path

The path to the trained model.

scaler_param_path : Path

The path to the pickle file containing scaler parameters.

contact_force_thr : tuple[float, float, float], optional

The threshold values for contact forces that are used to determine the floor contact flags, by default (0.5, 1, 3).
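To illustrate the one-observation-at-a-time idea, here is a hypothetical closure, make_step_predictor. It is not the interface of HeadStabilizationInferenceWrapper, whose prediction method is not documented in this section. At each step it scales one joint-angle vector, thresholds one set of contact forces, adds a batch axis of size 1, and calls a model. The leg ordering, the reuse of the three thresholds per side, and the input ordering are assumptions, as in the sketches for the other classes.

```python
from typing import Callable

import numpy as np


def make_step_predictor(model: Callable[[np.ndarray], np.ndarray], mean, std,
                        contact_force_thr=(0.5, 1, 3)):
    """Hypothetical per-time-step wrapper; not the flygym class's interface."""
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)
    # Assumed: thresholds are (front, middle, hind), reused for left and right legs.
    thr = np.tile(np.asarray(contact_force_thr, dtype=float), 2)

    def predict(joint_angles: np.ndarray, contact_forces: np.ndarray) -> np.ndarray:
        scaled = (joint_angles - mean) / std            # shape (n_joints,)
        contact = (contact_forces > thr).astype(float)  # shape (6,)
        x = np.concatenate([scaled, contact])[None, :]  # batch of one: (1, n_joints + 6)
        return model(x)[0]                              # (roll, pitch) for this step

    return predict


# Dummy "model" that returns the first two (scaled) joint angles, for illustration.
predict = make_step_predictor(lambda x: x[:, :2], mean=np.zeros(3), std=np.ones(3))
rp = predict(np.array([0.1, -0.2, 0.3]), np.array([0.0, 0.0, 0.0, 0.0, 2.0, 1.0]))
```

With the scaling and thresholding kept inside the closure, a closed-loop controller only passes raw per-step observations, which matches the "flat" interface described for this class.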

Utilities

flygym.examples.head_stabilization.util.get_head_stabilization_model_paths() -> tuple[Path, Path]

Get the paths to the head stabilization models.

Returns:
Path

Path to the head stabilization model checkpoint.

Path

Path to the pickle file containing joint angle scaler parameters.