Head stabilization¶
Note
Author: Sibo Wang-Chen
The code presented in this notebook has been simplified and restructured for display in a notebook format. A more complete and better structured implementation can be found in the examples folder of the FlyGym repository on GitHub.
This tutorial is available in .ipynb format in the notebooks folder of the FlyGym repository.
Summary: In this tutorial, we will use mechanosensory information to correct for self-motion in closed loop. We will train an internal model that predicts the appropriate neck actuation signals, based on leg joint angles and ground contacts, to minimize head rotations during walking.
Introduction¶
In the previous tutorial, we demonstrated how one can integrate ascending mechanosensory information to estimate the fly’s position. In this tutorial, we will demonstrate another way in which the fly can use ascending mechanosensory signals to complete the sensorimotor control loop.
In flies, head stabilization has been shown to compensate for body pitch and roll during locomotion (Kress & Egelhaaf, 2012). It is thought that these stabilizing movements may be informed by leg sensory feedback signals (Gollin & Dürr, 2018). To explore head stabilization in our embodied model, we will design a controller in which leg joint angles (i.e., proprioceptive signals; 6 legs × 7 degrees of freedom per leg) and ground contacts (i.e., tactile signals; 6 legs) are given as inputs to a multilayer perceptron (MLP). This model, in turn, predicts the appropriate head roll and pitch required to cancel visual rotations caused by the animal’s own body movements during walking. We will use these signals to actuate the neck joint, aiming to dampen head rotations. This approach is illustrated as follows:
Collecting training data¶
We start by running short simulations of walking while recording joint angles, ground contacts, and head rotations. This will give us a set of input-output pairs to use as training data. To run the simulations, we implement the following function:
import numpy as np
import pickle
import cv2
from tqdm import trange
from pathlib import Path
from typing import Optional
from dm_control.utils import transformations
from dm_control.rl.control import PhysicsError
from flygym import Fly, Camera
from flygym.arena import FlatTerrain, BlocksTerrain
from flygym.preprogrammed import get_cpg_biases
from flygym.examples.locomotion import HybridTurningController
def run_simulation(
gait: str = "tripod",
terrain: str = "flat",
spawn_xy: tuple[float, float] = (0, 0),
dn_drive: tuple[float, float] = (1, 1),
sim_duration: float = 0.5,
enable_rendering: bool = False,
live_display: bool = False,
output_dir: Optional[Path] = None,
pbar: bool = False,
):
"""Simulate locomotion and collect proprioceptive information to train
a neural network for head stabilization.
Parameters
----------
gait : str, optional
The type of gait for the fly. Choose from ['tripod', 'tetrapod',
'wave']. Defaults to "tripod".
terrain : str, optional
The type of terrain for the fly. Choose from ['flat', 'blocks'].
Defaults to "flat".
spawn_xy : tuple[float, float], optional
The x and y coordinates of the fly's spawn position. Defaults to
(0, 0).
    dn_drive : tuple[float, float], optional
        The descending neuron (DN) drive values for the left and right
        sides of the fly. Defaults to (1, 1).
sim_duration : float, optional
The duration of the simulation in seconds. Defaults to 0.5.
enable_rendering: bool, optional
If True, enables rendering. Defaults to False.
live_display : bool, optional
If True, enables live display. Defaults to False.
output_dir : Path, optional
The directory to which output files are saved. Defaults to None.
pbar : bool, optional
If True, enables progress bar. Defaults to False.
Raises
------
ValueError
Raised when an unknown terrain type is provided.
"""
if (not enable_rendering) and live_display:
raise ValueError("Cannot enable live display without rendering.")
# Set up arena
if terrain == "flat":
arena = FlatTerrain()
elif terrain == "blocks":
arena = BlocksTerrain(height_range=(0.2, 0.2))
else:
raise ValueError(f"Unknown terrain type: {terrain}")
# Set up simulation
contact_sensor_placements = [
f"{leg}{segment}"
for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
]
fly = Fly(
enable_adhesion=True,
draw_adhesion=True,
detect_flip=True,
contact_sensor_placements=contact_sensor_placements,
spawn_pos=(*spawn_xy, 0.25),
)
cam = Camera(
fly=fly, camera_id="Animat/camera_left", play_speed=0.1, timestamp_text=True
)
sim = HybridTurningController(
arena=arena,
phase_biases=get_cpg_biases(gait),
fly=fly,
cameras=[cam],
timestep=1e-4,
)
# Main simulation loop
obs, info = sim.reset(0)
obs_hist, info_hist, action_hist = [], [], []
dn_drive = np.array(dn_drive)
physics_error, fly_flipped = False, False
iterator = trange if pbar else range
for _ in iterator(int(sim_duration / sim.timestep)):
action_hist.append(dn_drive)
try:
obs, reward, terminated, truncated, info = sim.step(dn_drive)
except PhysicsError:
print("Physics error detected!")
physics_error = True
break
if enable_rendering:
rendered_img = sim.render()[0]
# Get necessary angles
quat = sim.physics.bind(sim.fly.thorax).xquat
quat_inv = transformations.quat_inv(quat)
roll, pitch, yaw = transformations.quat_to_euler(quat_inv, ordering="XYZ")
info["roll"], info["pitch"], info["yaw"] = roll, pitch, yaw
obs_hist.append(obs)
info_hist.append(info)
if info["flip"]:
print("Flip detected!")
break
# Live display
if enable_rendering and live_display and rendered_img is not None:
cv2.imshow("rendered_img", rendered_img[:, :, ::-1])
cv2.waitKey(1)
# Save data if output_dir is provided
if output_dir is not None:
output_dir.mkdir(parents=True, exist_ok=True)
if enable_rendering:
cam.save_video(output_dir / "rendering.mp4")
with open(output_dir / "sim_data.pkl", "wb") as f:
data = {
"obs_hist": obs_hist,
"info_hist": info_hist,
"action_hist": action_hist,
"errors": {
"fly_flipped": fly_flipped,
"physics_error": physics_error,
},
}
pickle.dump(data, f)
With this function, we will run a short simulation using the descending drive [1.0, 1.0] to walk straight:
output_dir = Path("outputs/head_stabilization/")
output_dir.mkdir(parents=True, exist_ok=True)
run_simulation(
gait="tripod",
terrain="flat",
spawn_xy=(0, 0),
dn_drive=(1, 1),
sim_duration=0.5,
enable_rendering=True,
live_display=False,
output_dir=output_dir / "tripod_flat_train_set_1.00_1.00",
pbar=True,
)
100%|██████████| 5000/5000 [00:14<00:00, 338.80it/s]
As a sanity check, we can plot the trajectory of the fly:
import matplotlib.pyplot as plt
with open(output_dir / "tripod_flat_train_set_1.00_1.00/sim_data.pkl", "rb") as f:
sim_data_flat = pickle.load(f)
trajectory = np.array([obs["fly"][0, :2] for obs in sim_data_flat["obs_hist"]])
fig, ax = plt.subplots(figsize=(5, 2), tight_layout=True)
ax.plot(trajectory[:, 0], trajectory[:, 1], label="Trajectory")
ax.plot([0], [0], "ko", label="Origin")
ax.legend()
ax.set_aspect("equal")
ax.set_xlabel("x position (mm)")
ax.set_ylabel("y position (mm)")
fig.savefig(output_dir / "head_stabilization_trajectory_sample.png")
We can also plot the time series of the variables that we are interested in, namely:
Joint angles of all leg degrees of freedom (DoFs), 7 real values per leg per step
Leg contact mask, 1 Boolean value per leg per step
The appropriate neck roll needed to cancel out body rotation, 1 real value per step
The appropriate neck pitch needed to cancel out body rotation, 1 real value per step
Note that we do not correct for rotation about the yaw axis. Doing so would require delineating unintended body oscillations from intentional turning — a task outside the scope of this tutorial.
To get the leg contacts, we will use a contact force threshold of 0.5 mN for the front legs, 1 mN for the middle legs, and 3 mN for the hind legs — as was the case in the path integration tutorial.
To get the appropriate neck roll and pitch needed to cancel out body rotations, we will take the quaternion representing the thorax rotation, invert it, and convert it to Euler angles. Quaternions are a mathematical concept used to represent rotations in three dimensions. They avoid some of the pitfalls of other rotation representations, such as gimbal lock. However, quaternions are less intuitive to interpret and their elements do not directly correspond to the axes on the fly body. Therefore, we convert the inverted angles to Euler angles with more familiar axes of rotation (pitch, roll, yaw). More information about representation of 3D rotation can be found on this Wikipedia article.
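As a minimal illustration of this computation (a hypothetical example, assuming a unit quaternion in (w, x, y, z) order, as used by MuJoCo):
angle = np.deg2rad(10)  # suppose the thorax has rolled 10 degrees about the x axis
quat = np.array([np.cos(angle / 2), np.sin(angle / 2), 0, 0])
quat_inv = transformations.quat_inv(quat)  # rotation that cancels the body rotation
roll, pitch, yaw = transformations.quat_to_euler(quat_inv, ordering="XYZ")
print(np.rad2deg([roll, pitch, yaw]))  # approximately [-10, 0, 0]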
For simplicity of visualization, we will only plot the legs on the left side:
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
dofs_per_leg = [
"ThC pitch",
"ThC roll",
"ThC yaw",
"CTr pitch",
"CTr roll",
"FTi pitch",
"TiTa pitch",
]
contact_force_thr = np.array([0.5, 1.0, 3.0, 0.5, 1.0, 3.0]) # LF LM LH RF RM RH
def visualize_trial_data(obs_hist, info_hist, output_path):
t_grid = np.arange(len(obs_hist)) * 1e-4
# Extract joint angles
joint_angles = np.array([obs["joints"][0, :] for obs in obs_hist])
# Extract ground contact
contact_forces = np.array([obs["contact_forces"] for obs in obs_hist])
# get magnitude from xyz vector:
contact_forces = np.linalg.norm(contact_forces, axis=2)
# sum over 6 segments per leg (contact sensing enabled for tibia and 5 tarsal segments):
contact_forces = contact_forces.reshape(-1, 6, 6).sum(axis=2)
contact_mask = contact_forces >= contact_force_thr
# Extract head rotation
roll = np.array([info["roll"] for info in info_hist])
pitch = np.array([info["pitch"] for info in info_hist])
# Visualize
fig, axs = plt.subplots(
6, 1, figsize=(6, 9), tight_layout=True, height_ratios=[3, 3, 3, 2, 3, 1]
)
# Legs
for i, leg in enumerate(["Left front leg", "Left middle leg", "Left hind leg"]):
ax = axs[i]
# Plot joint angles
for j, dof in enumerate(dofs_per_leg):
dof_idx = i * len(dofs_per_leg) + j
ax.plot(t_grid, np.rad2deg(joint_angles[:, dof_idx]), label=dof, lw=1)
ax.set_title(leg)
ax.set_ylabel(r"Joint angle ($^\circ$)")
ax.set_ylim(-180, 180)
ax.set_yticks([-180, -90, 0, 90, 180])
# Plot ground contact
bool_ts = contact_mask[:, i]
diff_ts = np.diff(bool_ts.astype(int), prepend=0)
if bool_ts[0]:
diff_ts[0] = 1
if bool_ts[-1]:
diff_ts[-1] = -1
upedges = np.where(diff_ts == 1)[0]
downedges = np.where(diff_ts == -1)[0]
for up, down in zip(upedges, downedges):
ax.axvspan(
t_grid[up],
t_grid[down],
color="black",
alpha=0.2,
lw=0,
label="Ground contact",
)
ax.set_xlabel("Time (s)")
# Leg legends
legend_elements = []
for j, dof in enumerate(dofs_per_leg):
legend_elements.append(Line2D([0], [0], color=f"C{j}", lw=1, label=dof))
legend_elements.append(
Patch(color="black", alpha=0.2, lw=0, label="Ground contact")
)
axs[3].legend(
bbox_to_anchor=(0, 1.1, 1, 0.2),
handles=legend_elements,
loc="upper center",
ncols=3,
mode="expand",
frameon=False,
)
axs[3].axis("off")
# Head movement
ax = axs[4]
ax.plot(t_grid, np.rad2deg(roll), label="Head roll", lw=2, color="midnightblue")
ax.plot(t_grid, np.rad2deg(pitch), label="Head pitch", lw=2, color="saddlebrown")
ax.set_title("Head movement")
ax.set_ylabel(r"Angle ($^\circ$)")
ax.set_ylim(-20, 20)
ax.set_xlabel("Time (s)")
# Head legends
    legend_elements = [
        Line2D([0], [0], color="midnightblue", lw=2, label="Roll"),
        Line2D([0], [0], color="saddlebrown", lw=2, label="Pitch"),
    ]
axs[5].legend(
bbox_to_anchor=(0, 1.4, 1, 0.2),
handles=legend_elements,
loc="upper center",
ncols=2,
mode="expand",
frameon=False,
)
axs[5].axis("off")
fig.savefig(output_path)
visualize_trial_data(
sim_data_flat["obs_hist"],
sim_data_flat["info_hist"],
output_dir / "head_stabilization_flat_terrain_ts_sample.png",
)
We observe that, after about 0.1 seconds of transient response, the gait cycles are indeed discernible from the input variables.
If we run another simulation over rugged terrain, the body oscillations appear more dramatic:
run_simulation(
gait="tripod",
terrain="blocks",
spawn_xy=(0, 0),
dn_drive=(1, 1),
sim_duration=0.5,
enable_rendering=True,
live_display=False,
output_dir=output_dir / "tripod_blocks_train_set_1.00_1.00",
pbar=True,
)
100%|██████████| 5000/5000 [00:21<00:00, 235.63it/s]
with open(output_dir / "tripod_blocks_train_set_1.00_1.00/sim_data.pkl", "rb") as f:
sim_data_blocks = pickle.load(f)
visualize_trial_data(
sim_data_blocks["obs_hist"],
sim_data_blocks["info_hist"],
output_dir / "head_stabilization_blocks_terrain_ts_sample.png",
)
Training an internal model to control neck actuation¶
In the previous section, we extracted the ascending sensory signals and the target motor outputs that will serve as the model’s inputs and outputs. Now, we will train a multilayer perceptron (MLP) that predicts the appropriate neck actuation signals from this ascending mechanosensory information. We will split this task into three technical steps:
Implementing a custom PyTorch dataset class to feed our data, through a dataloader, into the model
Defining a three-layer MLP
Training the MLP using the data we have gathered and the data pipeline we have developed
Implementing a custom PyTorch dataset¶
When training any machine learning or statistical model, it is often desirable to normalize or standardize the inputs. We will start by implementing a JointAngleScaler class to standardize joint angle data (subtract the mean, divide by the standard deviation). This class can be initialized in one of two ways:
A .from_data method that calculates the mean and standard deviation from a given dataset.
A .from_params method that uses user-specified mean and standard deviation values.
This way, we can compute the mean and standard deviation from one trial and use the same parameters on all datasets.
class JointAngleScaler:
"""
A class for standardizing joint angles (i.e., using mean and standard
deviation.
Attributes
----------
mean : np.ndarray
The mean values used for scaling.
std : np.ndarray
The standard deviation values used for scaling.
"""
@classmethod
def from_data(cls, joint_angles: np.ndarray):
"""
Create a JointAngleScaler instance from joint angle data. The mean
and standard deviation values are calculated from the data.
Parameters
----------
joint_angles : np.ndarray
The joint angle data. The shape should be (n_samples, n_joints)
where n_samples is, for example, the length of a time series of
joint angles.
Returns
-------
JointAngleScaler
A JointAngleScaler instance.
"""
scaler = cls()
scaler.mean = np.mean(joint_angles, axis=0)
scaler.std = np.std(joint_angles, axis=0)
return scaler
@classmethod
def from_params(cls, mean: np.ndarray, std: np.ndarray):
"""
Create a JointAngleScaler instance from predetermined mean and
standard deviation values.
Parameters
----------
mean : np.ndarray
The mean values. The shape should be (n_joints,).
std : np.ndarray
The standard deviation values. The shape should be (n_joints,).
Returns
-------
JointAngleScaler
A JointAngleScaler instance.
"""
scaler = cls()
scaler.mean = mean
scaler.std = std
return scaler
def __call__(self, joint_angles: np.ndarray):
"""
Scale the given joint angles.
Parameters
----------
joint_angles : np.ndarray
The joint angles to be scaled. The shape should be (n_samples,
n_joints) where n_samples is, for example, the length of a time
series of joint angles.
Returns
-------
np.ndarray
The scaled joint angles.
"""
return (joint_angles - self.mean) / self.std
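As a quick illustration of the two construction paths (a hypothetical check, not part of the training pipeline), a scaler fitted on data and a scaler rebuilt from the fitted parameters behave identically:
rng = np.random.default_rng(0)
angles = rng.normal(loc=0.5, scale=2.0, size=(100, 42))  # dummy joint angles
scaler_a = JointAngleScaler.from_data(angles)
scaler_b = JointAngleScaler.from_params(mean=scaler_a.mean, std=scaler_a.std)
assert np.allclose(scaler_a(angles), scaler_b(angles))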
Then, we will construct a PyTorch dataset class. This class can be seen as an “adapter”: on one side, it interfaces with the specifics of our data (data structure, format, etc.); on the other side, it outputs what PyTorch models expect, so that the neural network can work with it. See this tutorial from PyTorch for more details on the Dataset interface.
from torch.utils.data import Dataset
from typing import Optional, Callable
class WalkingDataset(Dataset):
"""
PyTorch Dataset class for walking data.
Parameters
----------
sim_data_file : Path
The path to the simulation data file.
contact_force_thr : tuple[float, float, float], optional
The threshold values for contact forces, by default (0.5, 1, 3).
joint_angle_scaler : Optional[Callable], optional
A callable object used to scale joint angles, by default None.
ignore_first_n : int, optional
The number of initial data points to ignore, by default 200.
joint_mask : Optional, optional
A mask to apply on joint angles, by default None.
Attributes
----------
gait : str
The type of gait.
terrain : str
The type of terrain.
subset : str
The subset of the data, i.e., "train" or "test".
dn_drive : str
The DN drive used to generate the data.
contact_force_thr : np.ndarray
The threshold values for contact forces.
joint_angle_scaler : Callable
The callable object used to scale joint angles.
ignore_first_n : int
The number of initial data points to ignore.
joint_mask : Optional
The mask applied on joint angles. This is used to zero out certain
DoFs to evaluate which DoFs are likely more important for head
stabilization.
contains_fly_flip : bool
Indicates if the simulation data contains fly flip errors.
contains_physics_error : bool
Indicates if the simulation data contains physics errors.
roll_pitch_ts : np.ndarray
The optimal roll and pitch correction angles. The shape is
(n_samples, 2).
joint_angles : np.ndarray
The scaled joint angle time series. The shape is (n_samples,
n_joints).
contact_mask : np.ndarray
The contact force mask (i.e., 1 if leg touching the floor, 0
otherwise). The shape is (n_samples, 6).
"""
def __init__(
self,
sim_data_file: Path,
contact_force_thr: tuple[float, float, float] = (0.5, 1, 3),
joint_angle_scaler: Optional[Callable] = None,
ignore_first_n: int = 200,
joint_mask=None,
) -> None:
super().__init__()
trial_name = sim_data_file.parent.name
gait, terrain, subset, _, dn_left, dn_right = trial_name.split("_")
self.gait = gait
self.terrain = terrain
self.subset = subset
self.dn_drive = f"{dn_left}_{dn_right}"
self.contact_force_thr = np.array([*contact_force_thr, *contact_force_thr])
self.joint_angle_scaler = joint_angle_scaler
self.ignore_first_n = ignore_first_n
self.joint_mask = joint_mask
with open(sim_data_file, "rb") as f:
sim_data = pickle.load(f)
self.contains_fly_flip = sim_data["errors"]["fly_flipped"]
self.contains_physics_error = sim_data["errors"]["physics_error"]
# Extract the roll and pitch angles
roll = np.array([info["roll"] for info in sim_data["info_hist"]])
pitch = np.array([info["pitch"] for info in sim_data["info_hist"]])
self.roll_pitch_ts = np.stack([roll, pitch], axis=1)
self.roll_pitch_ts = self.roll_pitch_ts[self.ignore_first_n :, :]
# Extract joint angles and scale them
joint_angles_raw = np.array(
[obs["joints"][0, :] for obs in sim_data["obs_hist"]]
)
if self.joint_angle_scaler is None:
self.joint_angle_scaler = JointAngleScaler.from_data(joint_angles_raw)
self.joint_angles = self.joint_angle_scaler(joint_angles_raw)
self.joint_angles = self.joint_angles[self.ignore_first_n :, :]
# Extract contact forces
contact_forces = np.array(
[obs["contact_forces"] for obs in sim_data["obs_hist"]]
)
contact_forces = np.linalg.norm(contact_forces, axis=2) # magnitude
contact_forces = contact_forces.reshape(-1, 6, 6).sum(axis=2) # sum per leg
self.contact_mask = (contact_forces >= self.contact_force_thr).astype(np.int16)
self.contact_mask = self.contact_mask[self.ignore_first_n :, :]
def __len__(self):
return self.roll_pitch_ts.shape[0]
def __getitem__(self, idx):
joint_angles = self.joint_angles[idx].astype(np.float32, copy=True)
if self.joint_mask is not None:
joint_angles[~self.joint_mask] = 0
return {
"roll_pitch": self.roll_pitch_ts[idx].astype(np.float32),
"joint_angles": joint_angles,
"contact_mask": self.contact_mask[idx].astype(np.float32),
}
We can test the joint angle scaler and dataset classes using our trial simulation:
joint_angles = np.array([obs["joints"][0, :] for obs in sim_data_flat["obs_hist"]])
joint_scaler = JointAngleScaler.from_data(joint_angles)
dataset = WalkingDataset(
sim_data_file=output_dir / "tripod_flat_train_set_1.00_1.00/sim_data.pkl",
joint_angle_scaler=joint_scaler,
ignore_first_n=200,
)
with open(output_dir / "head_stabilization_joint_angle_scaler_params.pkl", "wb") as f:
pickle.dump({"mean": joint_scaler.mean, "std": joint_scaler.std}, f)
Let’s plot the joint angles for the left front leg again, but this time using the dataset as an iterator instead of the output returned by run_simulation:
t_grid = np.arange(200, 200 + len(dataset)) * 1e-4
joint_angles = np.array([entry["joint_angles"] for entry in dataset])
fig, ax = plt.subplots(figsize=(6, 3), tight_layout=True)
ax.axhline(0, color="black", lw=1)
ax.axhspan(-1, 1, color="black", alpha=0.2, lw=0)
for i, dof in enumerate(dofs_per_leg):
ax.plot(t_grid, joint_angles[:, i], label=dof, lw=1)
ax.legend(
bbox_to_anchor=(0, 1.02, 1, 0.2),
loc="lower left",
mode="expand",
borderaxespad=0,
ncol=4,
)
ax.set_xlim(0, 0.5)
ax.set_ylim(-3, 3)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Standardized joint angle (AU)")
fig.savefig(output_dir / "head_stabilization_joint_angles_scaled.png")
We observe that the joint angles now have a mean of approximately 0 (black line) and a standard deviation of approximately 1 (gray shade).
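We can verify this numerically (only approximately — the scaler was fitted on the full time series, while the dataset drops the first 200 transient steps):
print(joint_angles.mean(axis=0).round(1))  # all values close to 0
print(joint_angles.std(axis=0).round(1))   # all values close to 1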
We can further use the PyTorch dataloader to fetch data in batches. This is useful for training the MLP in the next step. As an example, we can create a dataloader that yields shuffled batches of 32 samples at a time:
from torch.utils.data import DataLoader
example_loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in example_loader:
for key, value in batch.items():
print(f"{key}\tshape: {value.shape}")
break
roll_pitch shape: torch.Size([32, 2])
joint_angles shape: torch.Size([32, 42])
contact_mask shape: torch.Size([32, 6])
Defining an MLP¶
Having implemented the data pipeline, we will now define the model itself. We will use PyTorch Lightning, a framework built on top of PyTorch that simplifies checkpointing (saving snapshots of model parameters during training), logging, etc.
In brief, our ThreeLayerMLP class, implemented below, consists of the following:
An __init__ method that creates the three linear layers and an R2Score object that calculates the \(R^2\) score.
A forward method that implements the forward pass of the neural network — a process where we traverse the layers of the network to calculate the values of the output layer based on the input. In our case, we simply apply the three layers sequentially, with a Rectified Linear Unit (ReLU) activation function at the end of the first two layers. Based on this method, PyTorch will automatically implement the backward pass — a process in gradient-based optimization algorithms where, after the forward pass, the gradients for the parameters in all layers are traced, starting from the gradient of the loss on the outputs (i.e., the last layer).
A configure_optimizers method that sets up the optimizer — in our case an Adam optimizer with a learning rate of 0.001.
A training_step method that defines the operation to be conducted at each training step (i.e., every time the model receives a new batch of training data). Here, we concatenate the joint angles and leg contact masks into a single input block, run the forward pass (we can simply call the module itself on the input for this), and calculate the MSE loss. Then, we log the loss as the training loss and return it. PyTorch Lightning will do the backpropagation for us.
A validation_step method that defines what the model should do every time a batch of validation data is received. Similar to training_step, we run the forward pass, but this time we calculate the \(R^2\) scores in addition to the MSE loss. Lastly, we log the \(R^2\) and MSE metrics accordingly.
For more information on implementing a PyTorch Lightning module, see this tutorial.
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as pl
from torchmetrics.regression import R2Score
pl.seed_everything(0, workers=True)
class ThreeLayerMLP(pl.LightningModule):
"""
A PyTorch Lightning module for a three-layer MLP that predicts the
head roll and pitch correction angles based on proprioception and
tactile information.
"""
def __init__(self):
super().__init__()
input_size = 42 + 6
hidden_size = 32
output_size = 2
self.layer1 = nn.Linear(input_size, hidden_size)
self.layer2 = nn.Linear(hidden_size, hidden_size)
self.layer3 = nn.Linear(hidden_size, output_size)
self.r2_score = R2Score()
def forward(self, x):
"""
Forward pass through the model.
Parameters
----------
x : torch.Tensor
The input tensor. The shape should be (n_samples, 42 + 6)
where 42 is the number of joint angles and 6 is the number of
contact masks.
"""
x = F.relu(self.layer1(x))
x = F.relu(self.layer2(x))
return self.layer3(x)
def configure_optimizers(self):
"""Use the Adam optimizer."""
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def training_step(self, batch, batch_idx):
"""Training step of the PyTorch Lightning module."""
x = torch.concat([batch["joint_angles"], batch["contact_mask"]], dim=1)
y = batch["roll_pitch"]
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
"""Validation step of the PyTorch Lightning module."""
x = torch.concat([batch["joint_angles"], batch["contact_mask"]], dim=1)
y = batch["roll_pitch"]
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
self.log("val_loss", loss)
if y.shape[0] > 1:
r2_roll = self.r2_score(y_hat[:, 0], y[:, 0])
r2_pitch = self.r2_score(y_hat[:, 1], y[:, 1])
else:
r2_roll, r2_pitch = np.nan, np.nan
self.log("val_r2_roll", r2_roll)
self.log("val_r2_pitch", r2_pitch)
INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
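Before training, we can sanity-check the model’s interface with a dummy batch (an illustrative smoke test; the input width is 42 joint angles + 6 contact flags):
smoke_test_model = ThreeLayerMLP()
dummy_input = torch.rand(4, 42 + 6)  # a batch of 4 random input vectors
print(smoke_test_model(dummy_input).shape)  # torch.Size([4, 2]): roll and pitch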
Training the model¶
Having implemented the data pipeline and defined the model, we will now train the model. We have pre-generated 126 simulation trials, including 11 training trials and 10 testing trials with different descending drives, for each of the three gait patterns (tripod gait, tetrapod gait, and wave gait), and for flat and blocks terrain types. Of these, we exclude one simulation (wave gait, blocks terrain, test set, DN drives [0.58, 1.14]) because the fly flipped while walking. You can download this dataset by running the code block below.
# TODO. We are working with our IT team to set up a gateway to share these data publicly
# in a secure manner. We aim to update this by the end of June, 2024. Please reach out
# to us by email in the meantime.
simulation_data_dir = (
Path.home() / "Data/flygym_demo_data/head_stabilization/random_exploration/"
)
if not simulation_data_dir.is_dir():
raise FileNotFoundError(
"Pregenerated simulation data not found. Please download it from TODO."
)
else:
print(f"[OK] Pregenerated simulation data found. Ready to proceed.")
[OK] Pregenerated simulation data found. Ready to proceed.
Let’s generate a WalkingDataset
object (implemented above) for each
training trial and concatenate them.
from torch.utils.data import ConcatDataset
dataset_list = []
for gait in ["tripod", "tetrapod", "wave"]:
for terrain in ["flat", "blocks"]:
paths = simulation_data_dir.glob(f"{gait}_{terrain}_train_set_*")
print(f"Loading {gait} gait, {terrain} terrain...")
dn_drives = ["_".join(p.name.split("_")[-2:]) for p in paths]
for dn_drive in dn_drives:
sim = f"{gait}_{terrain}_train_set_{dn_drive}"
path = simulation_data_dir / f"{sim}/sim_data.pkl"
ds = WalkingDataset(path, joint_angle_scaler=joint_scaler)
ds.joint_mask = np.ones(42, dtype=bool) # use all joints
dataset_list.append(ds)
concat_train_set = ConcatDataset(dataset_list)
print(f"Training dataset size: {len(dataset)}")
Loading tripod gait, flat terrain...
Loading tripod gait, blocks terrain...
Loading tetrapod gait, flat terrain...
Loading tetrapod gait, blocks terrain...
Loading wave gait, flat terrain...
Loading wave gait, blocks terrain...
Training dataset size: 976800
The size is as expected: (3 gaits × 2 terrain types × 11 DN drive combinations) × (1.5 seconds of simulation / 0.0001 seconds per step – 200 transient steps excluded) = 66 × 14,800 = 976,800 samples in total.
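The same bookkeeping can be double-checked in a few lines:
n_trials = 3 * 2 * 11  # gaits × terrain types × DN drive combinations
steps_per_trial = int(1.5 / 1e-4) - 200  # steps per trial minus transient
print(n_trials * steps_per_trial)  # 976800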
We will further divide the training set into a training set and a validation set at a ratio of 4:1:
The training set is used to optimize the parameters of the model.
The validation set is used to check if the model has been overfitted.
The testing set is held out throughout the entire training procedure. It consists of trials simulated using a different set of descending drives and is only used to report the final out-of-sample performance of the model.
from torch.utils.data import random_split
train_ds, val_ds = random_split(concat_train_set, [0.8, 0.2])
As demonstrated above, we will create dataloaders for the training and validation sets to load the data in batches:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_ds, batch_size=256, num_workers=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1028, num_workers=4, shuffle=False)
Finally, we will set up a logger to keep track of the training progress, a checkpoint callback that saves snapshots of model parameters while training, and a trainer object to orchestrate the training procedure:
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from shutil import rmtree
log_dir = Path(output_dir / "logs")
if log_dir.is_dir():
rmtree(log_dir)
logger = CSVLogger(log_dir, name="demo_trial")
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
dirpath=output_dir / "models/checkpoints",
filename="%s-{epoch:02d}-{val_loss:.2f}",
save_top_k=1, # Save only the best checkpoint
mode="min", # `min` for minimizing the validation loss
)
model = ThreeLayerMLP()
trainer = pl.Trainer(
logger=logger,
callbacks=[checkpoint_callback],
max_epochs=10,
check_val_every_n_epoch=1,
deterministic=True,
)
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
We are now ready to train the model. We will train the model for 10 epochs. On a machine with an NVIDIA GeForce RTX 3080 Ti GPU (2021), this takes about 2 minutes.
trainer.fit(model, train_loader, val_loader)
WARNING: Missing logger folder: outputs/head_stabilization/logs/demo_trial
WARNING:lightning.fabric.loggers.csv_logs:Missing logger folder: outputs/head_stabilization/logs/demo_trial
INFO:
| Name | Type | Params
-------------------------------------
0 | layer1 | Linear | 1.6 K
1 | layer2 | Linear | 1.1 K
2 | layer3 | Linear | 66
3 | r2_score | R2Score | 0
-------------------------------------
2.7 K Trainable params
0 Non-trainable params
2.7 K Total params
0.011 Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
| Name | Type | Params
-------------------------------------
0 | layer1 | Linear | 1.6 K
1 | layer2 | Linear | 1.1 K
2 | layer3 | Linear | 66
3 | r2_score | R2Score | 0
-------------------------------------
2.7 K Trainable params
0 Non-trainable params
2.7 K Total params
0.011 Total estimated model params size (MB)
INFO: Trainer.fit stopped: max_epochs=10 reached.
INFO:lightning.pytorch.utilities.rank_zero:Trainer.fit stopped: max_epochs=10 reached.
Let’s inspect how the model’s performance on the training and validation sets changed over time. For the validation set, we will plot the loss and \(R^2\) scores at the end of each epoch.
import pandas as pd
logs = pd.read_csv(log_dir / "demo_trial/version_0/metrics.csv")
fig, axs = plt.subplots(2, 1, figsize=(5, 5), tight_layout=True, sharex=True)
ax = axs[0]
mask = np.isfinite(logs["train_loss"])
ax.plot(logs["step"][mask], logs["train_loss"][mask], label="Training loss")
mask = np.isfinite(logs["val_loss"])
ax.plot(logs["step"][mask], logs["val_loss"][mask], label="Validation loss", marker="o")
ax.legend()
ax.set_ylabel("MSE loss")
ax = axs[1]
ax.plot(
logs["step"][mask],
logs["val_r2_roll"][mask],
color="midnightblue",
label="Roll",
marker="o",
)
ax.plot(
logs["step"][mask],
logs["val_r2_pitch"][mask],
color="saddlebrown",
label="Pitch",
marker="o",
)
ax.legend(loc="lower right")
ax.set_xlabel("Step")
ax.set_ylabel("R² score")
fig.savefig(output_dir / "head_stabilization_training_metrics.png")
Satisfied with the performance, we now proceed to evaluate the model on the testing set and deploy it in closed loop.
Deploying the model¶
While the PyTorch module ThreeLayerMLP can give us predictions, it is not very lean: a number of training-related elements are exposed to the caller. For example, the forward method expects a batch of data concatenated in a specific way, and PyTorch will try to load the data onto accelerator hardware automatically if any is found. This is not ideal for real-time deployment — we will only get one input snapshot at a time, and the data are small enough and the steps frequent enough that it is not worth moving data to and from the GPU at every step. Therefore, as a next step, we will write a wrapper that provides a minimal interface for making single-step predictions natively on the CPU:
class HeadStabilizationInferenceWrapper:
"""
Wrapper for the head stabilization model to make predictions on
observations. Whereas data are collected in large tensors during
training, this class provides a "flat" interface for making predictions
one observation (i.e., time step) at a time. This is useful for
deploying the model in closed loop.
"""
def __init__(
self,
model_path: Path,
scaler_param_path: Path,
contact_force_thr: tuple[float, float, float] = (0.5, 1, 3),
):
"""
Parameters
----------
model_path : Path
The path to the trained model.
scaler_param_path : Path
The path to the pickle file containing scaler parameters.
contact_force_thr : tuple[float, float, float], optional
The threshold values for contact forces that are used to
determine the floor contact flags, by default (0.5, 1, 3).
"""
# Load scaler params
with open(scaler_param_path, "rb") as f:
scaler_params = pickle.load(f)
self.scaler_mean = scaler_params["mean"]
self.scaler_std = scaler_params["std"]
# Load model
# it's not worth moving data to the GPU, just run it on the CPU
self.model = ThreeLayerMLP.load_from_checkpoint(
model_path, map_location=torch.device("cpu")
)
self.contact_force_thr = np.array([*contact_force_thr, *contact_force_thr])
def __call__(
self, joint_angles: np.ndarray, contact_forces: np.ndarray
) -> np.ndarray:
"""
Make a prediction given joint angles and contact forces. This is
a light wrapper around the model's forward method and works without
batching.
Parameters
----------
joint_angles : np.ndarray
The joint angles. The shape should be (n_joints,).
contact_forces : np.ndarray
The contact forces. The shape should be (n_legs * n_segments).
Returns
-------
np.ndarray
The predicted roll and pitch angles. The shape is (2,).
"""
joint_angles = (joint_angles - self.scaler_mean) / self.scaler_std
contact_forces = np.linalg.norm(contact_forces, axis=1)
contact_forces = contact_forces.reshape(6, 6).sum(axis=1)
contact_mask = contact_forces >= self.contact_force_thr
x = np.concatenate([joint_angles, contact_mask], dtype=np.float32)
input_tensor = torch.tensor(x[None, :], device=torch.device("cpu"))
output_tensor = self.model(input_tensor)
return output_tensor.detach().numpy().squeeze()
Let’s load the model from the saved checkpoint:
model_wrapper = HeadStabilizationInferenceWrapper(
model_path=checkpoint_callback.best_model_path,
scaler_param_path=output_dir / "head_stabilization_joint_angle_scaler_params.pkl",
)
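As a quick check of the wrapper’s single-step interface, we can call it with dummy inputs (hypothetical zero arrays shaped like the simulation’s observations):
dummy_angles = np.zeros(42)  # 6 legs × 7 DoFs
dummy_forces = np.zeros((36, 3))  # 6 legs × 6 segments, xyz force components
print(model_wrapper(dummy_angles, dummy_forces).shape)  # (2,): roll and pitch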
To deploy the head stabilization model in closed loop, we will write a run_simulation_closed_loop function:
from flygym.arena import BaseArena
from sklearn.metrics import r2_score
contact_sensor_placements = [
f"{leg}{segment}"
for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
]
def run_simulation_closed_loop(
arena: BaseArena,
run_time: float = 0.5,
head_stabilization_model: Optional[HeadStabilizationInferenceWrapper] = None,
):
fly = Fly(
contact_sensor_placements=contact_sensor_placements,
vision_refresh_rate=500,
neck_kp=500,
head_stabilization_model=head_stabilization_model,
)
sim = HybridTurningController(fly=fly, arena=arena)
sim.reset(seed=0)
    # These histories are updated at every time step and are used for
    # generating statistics and plots at the end of the simulation.
head_rotation_hist = []
thorax_rotation_hist = []
neck_actuation_pred_hist = [] # model-predicted neck actuation
neck_actuation_true_hist = [] # ideal neck actuation
thorax_body = fly.model.find("body", "Thorax")
head_body = fly.model.find("body", "Head")
# Main simulation loop
for i in trange(int(run_time / sim.timestep)):
try:
obs, _, _, _, info = sim.step(action=np.array([1, 1]))
except PhysicsError:
print("Physics error, ending simulation early")
break
# Record neck actuation for stats at the end of the simulation
if head_stabilization_model is not None:
neck_actuation_pred_hist.append(info["neck_actuation"])
quat = sim.physics.bind(fly.thorax).xquat
quat_inv = transformations.quat_inv(quat)
roll, pitch, _ = transformations.quat_to_euler(quat_inv, ordering="XYZ")
neck_actuation_true_hist.append(np.array([roll, pitch]))
# Record head and thorax orientation
thorax_rotation_quat = sim.physics.bind(thorax_body).xquat
thorax_roll, thorax_pitch, _ = transformations.quat_to_euler(
thorax_rotation_quat, ordering="XYZ"
)
thorax_rotation_hist.append([thorax_roll, thorax_pitch])
head_rotation_quat = sim.physics.bind(head_body).xquat
head_roll, head_pitch, _ = transformations.quat_to_euler(
head_rotation_quat, ordering="XYZ"
)
head_rotation_hist.append([head_roll, head_pitch])
# Generate performance stats on head stabilization
if head_stabilization_model is not None:
neck_actuation_true_hist = np.array(neck_actuation_true_hist)
neck_actuation_pred_hist = np.array(neck_actuation_pred_hist)
r2_scores = {
# exclude the first 200 frames (transient response)
"roll": r2_score(
neck_actuation_true_hist[200:, 0], neck_actuation_pred_hist[200:, 0]
),
"pitch": r2_score(
neck_actuation_true_hist[200:, 1], neck_actuation_pred_hist[200:, 1]
),
}
else:
r2_scores = None
neck_actuation_true_hist = np.array(neck_actuation_true_hist)
neck_actuation_pred_hist = np.zeros_like(neck_actuation_true_hist)
return {
"sim": sim,
"neck_true": neck_actuation_true_hist,
"neck_pred": neck_actuation_pred_hist,
"r2_scores": r2_scores,
"head_rotation_hist": np.array(head_rotation_hist),
"thorax_rotation_hist": np.array(thorax_rotation_hist),
}
To apply the model-predicted neck actuation signals, we simply pass the model as the head_stabilization_model parameter to the Fly object. Under the hood, the Fly object initializes actuators for the neck roll and pitch DoFs upon __init__. Then, at each simulation step, the Fly class runs the head_stabilization_model and actuates the appropriate DoFs in addition to the user-specified actions. In code, this is implemented as follows:
class Fly:
def __init__(... head_stabilization_model ...):
...
# Check neck actuation if head stabilization is enabled
if head_stabilization_model is not None:
if "joint_Head_yaw" in actuated_joints or "joint_Head" in actuated_joints:
raise ValueError(
"The head joints are actuated by a preset algorithm. "
"However, the head joints are already included in the "
"provided Fly instance. Please remove the head joints from "
"the list of actuated joints."
)
self._last_neck_actuation = None # tracked only for head stabilization
...
self.actuated_joints = actuated_joints
self.head_stabilization_model = head_stabilization_model
...
if self.head_stabilization_model is not None:
self.neck_actuators = [
self.model.actuator.add(
self.control,
name=f"actuator_position_{joint}",
joint=joint,
kp=neck_kp,
ctrlrange="-1000000 1000000",
forcelimited=False,
)
for joint in ["joint_Head_yaw", "joint_Head"]
]
...
def pre_step(self, action, sim):
joint_action = action["joints"]
# estimate necessary neck actuation signals for head stabilization
if self.head_stabilization_model is not None:
if self._last_observation is not None:
leg_joint_angles = self._last_observation["joints"][0, :]
leg_contact_forces = self._last_observation["contact_forces"]
neck_actuation = self.head_stabilization_model(
leg_joint_angles, leg_contact_forces
)
else:
neck_actuation = np.zeros(2)
joint_action = np.concatenate((joint_action, neck_actuation))
self._last_neck_actuation = neck_actuation
        sim.physics.bind(self.actuators + self.neck_actuators).ctrl = joint_action
def post_step(self, sim):
obs, reward, terminated, truncated, info = ...
...
if self.head_stabilization_model is not None:
# this is tracked to decide neck actuation for the next step
info["neck_actuation"] = self._last_neck_actuation
return obs, reward, terminated, truncated, info
class Simulation:
...
def step(self, action):
...
self.fly.pre_step(action, self)
        obs, reward, terminated, truncated, info = self.fly.post_step(self)
return obs, reward, terminated, truncated, info
Now, we can run the simulation over flat and blocks terrain again:
arena = FlatTerrain()
sim_data_flat = run_simulation_closed_loop(
arena=arena, run_time=1, head_stabilization_model=model_wrapper
)
arena = BlocksTerrain(height_range=(0.2, 0.2))
sim_data_blocks = run_simulation_closed_loop(
arena=arena, run_time=1, head_stabilization_model=model_wrapper
)
100%|██████████| 10000/10000 [00:16<00:00, 594.56it/s]
100%|██████████| 10000/10000 [00:33<00:00, 299.90it/s]
print(f"R² scores over flat terrain: {sim_data_flat['r2_scores']}")
print(f"R² scores over blocks terrain: {sim_data_blocks['r2_scores']}")
R² scores over flat terrain: {'roll': 0.8720892058987814, 'pitch': 0.9293070918490837}
R² scores over blocks terrain: {'roll': 0.5792754921973917, 'pitch': 0.7106359552091986}
Based on these results, we can plot the time series of the model-predicted neck actuation signals and the ideal neck actuation signals:
fig, axs = plt.subplots(2, 1, figsize=(6, 5), tight_layout=True, sharex=True)
color_config = {
"roll": ("royalblue", "midnightblue"),
"pitch": ("peru", "saddlebrown"),
}
for ax, terrain, data in zip(axs, ["Flat", "Blocks"], [sim_data_flat, sim_data_blocks]):
t_grid = np.arange(len(data["neck_true"])) * 1e-4
for i, dof in enumerate(["roll", "pitch"]):
ax.plot(
t_grid,
np.rad2deg(data["neck_true"][:, i]),
label=f"Optimal {dof}",
linestyle="--",
color=color_config[dof][0],
)
        ax.plot(
            t_grid,
            np.rad2deg(data["neck_pred"][:, i]),
            label=f"Predicted {dof}",
            color=color_config[dof][1],
        )
ax.set_title(f"{terrain} terrain")
ax.set_ylabel(r"Target angle ($^\circ$)")
ax.set_ylim(-20, 20)
if terrain == "Flat":
ax.legend(ncols=2)
if terrain == "Blocks":
ax.set_xlabel("Time (s)")
fig.savefig(output_dir / "head_stabilization_neck_actuation_sample.png")
Similarly, we can plot the roll and pitch of the head compared to the thorax over time:
fig, axs = plt.subplots(
2, 2, figsize=(8, 5), tight_layout=True, sharex=True, sharey=True
)
for i, (terrain, data) in enumerate(
zip(["Flat", "Blocks"], [sim_data_flat, sim_data_blocks])
):
for j, dof in enumerate(["roll", "pitch"]):
ax = axs[j, i]
ax.axhline(0, color="black", lw=1)
ax.plot(
t_grid,
np.rad2deg(data["head_rotation_hist"][:, j]),
label="Head",
color="tab:red",
)
ax.plot(
t_grid,
np.rad2deg(data["thorax_rotation_hist"][:, j]),
label="Thorax",
color="tab:blue",
)
ax.set_ylim(-15, 15)
if i == 0 and j == 0:
ax.legend()
if i == 0:
ax.set_ylabel(rf"{dof.capitalize()} angle ($^\circ$)")
if j == 0:
ax.set_title(f"{terrain} terrain")
if j == 1:
ax.set_xlabel("Time (s)")
fig.savefig(output_dir / "head_stabilization_head_vs_thorax.png")
As expected, the rotation of the head has a lower magnitude than that of the body, even over complex terrain.