Vision basics

Author: Sibo Wang-Chen

Note: The code presented in this notebook has been simplified and restructured for display in a notebook format. A more complete and better structured implementation can be found in the examples folder of the FlyGym repository on GitHub.

Notebook Format: This tutorial is available in .ipynb format in the notebooks folder of the FlyGym repository.

Summary: In this tutorial, we will build a simple model to control the fly to follow a moving sphere. By doing so, we will also demonstrate how one can create a custom arena.

Animals typically navigate over rugged terrain to reach attractive objects (e.g. potential mates, food sources) and to avoid repulsive features (e.g. pheromones from predators) and obstacles. Animals use a hierarchical controller to achieve these goals: processing higher-order sensory signals, using them to select the next action, and translating these decisions into descending commands that drive lower-level motor systems. We aimed to simulate this sensorimotor hierarchy by adding vision and olfaction to NeuroMechFly.

Retina simulation

A fly’s compound eye consists of ∼700–750 individual units called ommatidia arranged in a hexagonal pattern (see the left panel of the figure below from the droso4schools project; see also this article from the Arizona Retina Project). To emulate this, we attached a color camera to each of our model’s compound eyes (top right panel). We then transformed each camera image into 721 bins, representing ommatidia. Based on previous studies, we assume a 270° combined azimuth for the fly’s field of view, with a ∼17° binocular overlap. Visual sensitivity has evolved to highlight ethologically relevant color spectra at different locations in the environment. Here, as an initial step toward enabling this heterogeneity in our model, we implemented yellow- and pale-type ommatidia—sensitive to the green and blue channels of images rendered by the physics simulator—randomly assigned at a 7:3 ratio (as reported in Rister et al, 2013). Users can substitute the green and blue channel values with the desired light intensities sensed by yellow- and pale-type ommatidia to achieve more biorealistic chromatic vision.

In NeuroMechFly, the main interface to interact with the fly is the Retina class. It is embedded into the NeuroMechFly class as its .retina attribute if enable_vision is set to True in NeuroMechFly.sim_params. Here, let’s reproduce the top right panel of the figure above by putting the fly into an arena with three pillars and simulating the fly’s visual experience.

To start, we do the necessary imports:

import matplotlib.pyplot as plt
import numpy as np
from tqdm import trange
from gymnasium.utils.env_checker import check_env

from flygym import Fly, Camera
from flygym.arena import FlatTerrain
from import ObstacleOdorArena
from flygym.examples.locomotion import HybridTurningController
from pathlib import Path

output_dir = Path("./outputs/vision_basics")
output_dir.mkdir(exist_ok=True, parents=True)

We have pre-implemented an ObstacleOdorArena class with visual obstacles and an odor source. The details of this class is beyond the scope of this tutorial, but you can refer to it on the FlyGym github repository:

# We start by creating a simple arena
flat_terrain_arena = FlatTerrain()

# Then, we add visual and olfactory features on top of it
arena = ObstacleOdorArena(
    obstacle_positions=np.array([(7.5, 0), (12.5, 5), (17.5, -5)]),
    obstacle_colors=[(0.14, 0.14, 0.2, 1), (0.2, 0.8, 0.2, 1), (0.2, 0.2, 0.8, 1)],
    user_camera_settings=((13, -18, 9), (np.deg2rad(65), 0, 0), 45),

Let’s put the fly in it and simulate 500 timesteps so the fly can stand on the floor in a stable manner:

contact_sensor_placements = [
    for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
    for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]

fly = Fly(
    spawn_pos=(13, -5, 0.2),
    spawn_orientation=(0, 0, np.pi / 2 + np.deg2rad(70)),

cam = Camera(fly=fly, play_speed=0.2, camera_id="user_cam")
sim = HybridTurningController(fly=fly, cameras=[cam], arena=arena)

for i in range(500):
    obs, reward, terminated, truncated, info = sim.step(np.zeros(2))

fig, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
fig.savefig(output_dir / "vision_sim_env.png")

We can access the intensities sensed by the fly’s ommatidia from the observation:

print("Shape:", obs["vision"].shape)
print("Data type:", obs["vision"].dtype)
[[[0.9913793 0.       ]
  [1.        0.       ]
  [1.        0.       ]
  [0.4       0.       ]
  [0.4       0.       ]
  [0.4       0.       ]]

 [[0.9913793 0.       ]
  [1.        0.       ]
  [1.        0.       ]
  [0.4       0.       ]
  [0.4       0.       ]
  [0.4       0.       ]]]
Shape: (2, 721, 2)
Data type: float32

This gives us a (2, 721, 2) array representing the light intensities sensed by the ommatidia. The values are normalized to [0, 1]. The first dimension is for the two eyes (left and right in that order). The second dimension is for the 721 ommatidia per eye. The third dimension is for the two color channels (yellow- and pale-type in that order). For each ommatidia, only one of the two numbers along the last dimension is nonzero. The yellow- and pale-type ommatidia are split at a 7:3 ratio. We can verify this below:

nonzero_idx = np.nonzero(obs["vision"])
unique_vals, val_counts = np.unique(nonzero_idx[2], return_counts=True)
val_counts / val_counts.sum()
array([0.70041609, 0.29958391])

But this is array representation is not good for visualization. We can use the hex_pxls_to_human_readable method of the retina to convert it into a normal [0, 256) 8-bit RGB image that can be plotted. We set color_8bit to True to process the 8-bit color representation more efficiently and to return the output as an integer ranged from 0 to 255. We will further take the grayscale image (disregard yellow- vs pale-type ommatidia) by taking the maximum along the last dimension, i.e. that of color channels.

vision_left = fly.retina.hex_pxls_to_human_readable(
    obs["vision"][0, :, :], color_8bit=True
vision_left = vision_left.max(axis=-1)
vision_right = fly.retina.hex_pxls_to_human_readable(
    obs["vision"][1, :, :], color_8bit=True
vision_right = vision_right.max(axis=-1)

fig, axs = plt.subplots(1, 2, figsize=(6, 3), tight_layout=True)
axs[0].imshow(vision_left, cmap="gray", vmin=0, vmax=255)
axs[0].set_title("Left eye")
axs[1].imshow(vision_right, cmap="gray", vmin=0, vmax=255)
axs[1].set_title("Right eye")
fig.savefig(output_dir / "vision_sim.png")

Since render_raw_vision is set to True in the parameters, we can access the raw RGB vision through the info dictionary before pixels are binned into ommatidia:

fig, axs = plt.subplots(1, 2, figsize=(6, 3), tight_layout=True)
axs[0].imshow(info["raw_vision"][0, :, :, :].astype(np.uint8))
axs[0].set_title("Left eye")
axs[1].imshow(info["raw_vision"][1, :, :, :].astype(np.uint8))
axs[1].set_title("Right eye")
fig.savefig(output_dir / "vision_sim_raw.png")

We observe that the ommatidia covering the blue and the green pillars seem to have a bimodal distribution in intensity. This is because the ommatidia are stochastically split into yellow- and pale-types, and they have different sensitivities to different colors.

A dynamic arena with a moving sphere

The next step is to create a custom arena with a moving sphere. To do this, we will implement a MovingObjArena class that inherits from the flygym.arena.BaseArena class. A complete, functioning implementation of this class is provided under on the FlyGym repository. We start by defining some attributes in its __init__ method:

class MovingObjArena(BaseArena):
    """Flat terrain with a hovering moving object.

    arena : mjcf.RootElement
        The arena object that the terrain is built on.
    ball_pos : Tuple[float,float,float]
        The position of the floating object in the arena.

    size : Tuple[int, int]
        The size of the terrain in (x, y) dimensions.
    friction : Tuple[float, float, float]
        Sliding, torsional, and rolling friction coefficients, by default
        (1, 0.005, 0.0001)
    obj_radius : float
        Radius of the spherical floating object in mm.
    obj_spawn_pos : Tuple[float,float,float]
        Initial position of the object, by default (0, 2, 1).
    move_direction : str
        Which way the ball moves toward first. Can be "left", "right", or
    move_speed : float
        Speed of the moving object.

    def __init__(
        size: Tuple[float, float] = (300, 300),
        friction: Tuple[float, float, float] = (1, 0.005, 0.0001),
        obj_radius: float = 1,
        init_ball_pos: Tuple[float, float] = (5, 0),
        move_speed: float = 8,
        move_direction: str = "right",

        self.init_ball_pos = (*init_ball_pos, obj_radius)
        self.ball_pos = np.array(self.init_ball_pos, dtype="float32")
        self.friction = friction
        self.move_speed = move_speed
        self.curr_time = 0
        self.move_direction = move_direction
        if move_direction == "left":
            self.y_mult = 1
        elif move_direction == "right":
            self.y_mult = -1
        elif move_direction == "random":
            self.y_mult = np.random.choice([-1, 1])
            raise ValueError("Invalid move_direction")


Next, we define a root_element attribute. The simulated world is represented as a tree of objects, each attached to a parent. For example, the eyes of the fly are attached to the head, which is in turn attached to the thorax — the base of the NeuroMechFly model. Note that this tree is merely a representation of objects and their relation to each other; there does not necessarily have to be a visual or anatomical link between the parent and child objects. For example, the base of the NeuroMechFly model — the thorax — is attached to the world, but the link between the thorax and the world is a free joint, meaning that it is free to move in all 6 degrees of freedom. The root element is the root of this tree.


self.root_element = mjcf.RootElement()


Then, we will add the moving sphere. It will be attached to the root element:


# Add ball
obstacle = self.root_element.asset.add(
    "material", name="obstacle", reflectance=0.1
    "body", name="ball_mocap", mocap=True, pos=self.ball_pos, gravcomp=1
self.object_body = self.root_element.find("body", "ball_mocap")
    size=(obj_radius, obj_radius),
    rgba=(0.0, 0.0, 0.0, 1),


Let’s also add some cameras so we can visualize the scene from different angles. This concludes the definition of our __init__ method.


# Add camera
self.birdeye_cam = self.root_element.worldbody.add(
    pos=(15, 0, 35),
    euler=(0, 0, 0),
self.birdeye_cam_zoom = self.root_element.worldbody.add(
    pos=(15, 0, 20),
    euler=(0, 0, 0),

Next, let’s define a get_spawn_position class. This is applies an offset to the user-defined fly spawn position. For example, if there is a stage in your arena that is 1 mm high, and you want to place the fly on this stage, then you might want to apply a transformation to the user-specified relative spawn position and return rel_pos + np.array([0, 0, 1]) as the effective spawn position. In our case, we have a flat arena, so we will just return the spawn position and orientation as is:

def get_spawn_position(self, rel_pos, rel_angle):
    return rel_pos, rel_angle

The arena also has a step method. Usually, in static arenas, this is left empty. However, since we want the sphere to move in our arena, we need to implement this method so the sphere is moved appropriately every step of the simulation:

def step(self, dt, physics):
    heading_vec = np.array([1, 2 * np.cos(self.curr_time * 3) * self.y_mult])
    heading_vec /= np.linalg.norm(heading_vec)
    self.ball_pos[:2] += self.move_speed * heading_vec * dt
    physics.bind(self.object_body).mocap_pos = self.ball_pos
    self.curr_time += dt

Finally, let’s implement a reset method:

def reset(self, physics):
    if self.move_direction == "random":
        self.y_mult = np.random.choice([-1, 1])
    self.curr_time = 0
    self.ball_pos = np.array(self.init_ball_pos, dtype="float32")
    physics.bind(self.object_body).mocap_pos = self.ball_pos

Visual feature preprocessing

We will preprocess the visual feature by computing the x-y position of the object on the retina along with its size relative to the whole retinal image. We do this by applying binary thresholding to the image and calculating its size and center of mass. This is a good example to once again showcase the benefit of encapsulating preprogrammed logic into the Markov Decision Process (implemented as a Gym environment). If you haven’t, read the tutorial on building a turning controller to see how this is done.

Recall that in the HybridTurningController, we implemented the purple arrow in the following figure, encapsulating the CPG network and the sensory feedback-based correction rules:

Here, we will build yet another layer on top of HybridTurningController, implementing the aforementioned sensory preprocessing logic (cyan arrow) and encapsulating it inside the new MDP. As before, a complete, functioning implementation of this class is provided under on the FlyGym repository.

We start by defining an __init__ method. This time, we will specify the threshold used in the binary thresholding step. Any pixel darker than this number will be considered part of the black sphere. We will also define a decision interval \(\tau\): the turning signal is recomputed every \(\tau\) seconds. We will compute the center of mass (COM) of all ommatidia so that later when we need to compute the COM of the object (a masked subset of pixels), we can simply take the average of the COMs of the selected pixels. Finally, we will override the definition of the observation space with a 6-dimensional one (x, y positions of the object per side, plus the relative size of the object per side).

class VisualTaxis(HybridTurningNMF):
    def __init__(self, obj_threshold=0.15, decision_interval=0.05, **kwargs):

        self.obj_threshold = obj_threshold
        self.decision_interval = decision_interval
        self.num_substeps = int(self.decision_interval / self.timestep)
        self.visual_inputs_hist = []

        self.coms = np.empty((self.retina.num_ommatidia_per_eye, 2))
        for i in range(self.retina.num_ommatidia_per_eye):
            mask = self.retina.ommatidia_id_map == i + 1
            self.coms[i, :] = np.argwhere(mask).mean(axis=0)

        self.observation_space = spaces.Box(0, 1, shape=(6,))

Next, let’s implement the visual preprocessing logic discussed above:

def _process_visual_observation(self, raw_obs):
    features = np.zeros((2, 3))
    for i, ommatidia_readings in enumerate(raw_obs["vision"]):
        is_obj = ommatidia_readings.max(axis=1) < self.obj_threshold
        is_obj_coords = self.coms[is_obj]
        if is_obj_coords.shape[0] > 0:
            features[i, :2] = is_obj_coords.mean(axis=0)
        features[i, 2] = is_obj_coords.shape[0]
    features[:, 0] /= self.retina.nrows  # normalize y_center
    features[:, 1] /= self.retina.ncols  # normalize x_center
    features[:, 2] /= self.retina.num_ommatidia_per_eye  # normalize area
    return features.ravel()

In the step method, we will replace the raw observation with the output of the _process_visual_observation method. We will also record the retina images every time the simulation renders a frame for the recorded video. This way, we can visualize the retina images along with the recorded video later:

def step(self, control_signal):
    for _ in range(self.num_substeps):
        raw_obs, _, _, _, _ = super().step(control_signal)
        render_res = super().render()
        if render_res is not None:
            # record visual inputs too because they will be played in the video
    visual_features = self._process_visual_observation(raw_obs)
    return visual_features, 0, False, False, {}

Finally, we implement the reset method:

def reset(self, seed=0, **kwargs):
    raw_obs, _ = super().reset(seed=seed)
    return self._process_visual_observation(raw_obs), {}

Implementing a object tracking controller

Now that we have implemented the arena and the new Gym environment, we just need to define the actual controller logic that outputs the 2D descending representation based on the extracted visual features. As a proof-of-concept, we have hand-tuned the following relationship:

\[\begin{split}\delta_i = \begin{cases} \min(\max(k a_i + b, \delta_\text{min}), \delta_\text{max}) & \text{if } A_i > A_\text{thr} \\ 1 & \text{otherwise} \end{cases}\end{split}\]

where \(\delta_i\) is the descending modulation signal on side \(i\); \(a_i\) is the azimuth of the object expressed as the deviation from the anterior edge of the eye’s field of view, normalized by the horizontal field of view of the retina; \(A_i\) is the relative size of the object on the arena; \(k=-3\), \(b=1\) are parameters describing the response curve; \(\delta_\text{min}=0.2\), \(\delta_\text{max}=1\) describe the range of the descending signal; \(A_\text{thr} = 1\%\) is the threshold below which the object is considered unseen from the eye.

Before we implement this in the main simulation loop, let’s instantiate our arena and Gym environment:

from import MovingObjArena, VisualTaxis

obj_threshold = 0.2
decision_interval = 0.025
contact_sensor_placements = [
    for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
    for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
arena = MovingObjArena()
fly = Fly(
cam = Camera(
    window_size=(800, 608),
sim = VisualTaxis(
    intrinsic_freqs=np.ones(6) * 9,

As before, let’s check if this environment complies with the Gym interface. Despite a few warnings on design choices, no errors should be raised.

.../gymnasium/utils/ UserWarning: WARN: For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See for more information.
.../gymnasium/utils/ UserWarning: WARN: The default seed argument in reset should be None, otherwise the environment will by default always be deterministic. Actual default: 0
.../gymnasium/utils/ UserWarning: WARN: The default seed argument in Env.reset should be None, otherwise the environment will by default always be deterministic. Actual default: seed=0
.../gymnasium/utils/ UserWarning: WARN: Not able to test alternative render modes due to the environment not having a spec. Try instantialising the environment through gymnasium.make

Now, we implement the main simulation loop:

def calc_ipsilateral_speed(deviation, is_found):
    if not is_found:
        return 1.0
        return np.clip(1 - deviation * 3, 0.4, 1.2)

num_substeps = int(decision_interval / sim.timestep)

obs_hist = []
deviations_hist = []
control_signal_hist = []

obs, _ = sim.reset()
for i in trange(140):
    left_deviation = 1 - obs[1]
    right_deviation = obs[4]
    left_found = obs[2] > 0.01
    right_found = obs[5] > 0.01
    if not left_found:
        left_deviation = np.nan
    if not right_found:
        right_deviation = np.nan
    control_signal = np.array(
            calc_ipsilateral_speed(left_deviation, left_found),
            calc_ipsilateral_speed(right_deviation, right_found),

    obs, _, _, _, _ = sim.step(control_signal)
    deviations_hist.append([left_deviation, right_deviation])
100%|██████████| 140/140 [02:05<00:00,  1.11it/s]

To inspect the recorded video:

cam.save_video(output_dir / "object_following.mp4")

We can use the save_video_with_vision_insets utility function to regenerate this video, but with insets at the bottom illustrating the visual experience of the fly:

from import save_video_with_vision_insets

    output_dir / "object_following_with_retina_images.mp4",