
simulation

Simulation

CPU-based single-world physics simulation.

Wraps a compiled MuJoCo model and provides methods for stepping physics, reading state, and writing control inputs.
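
Internally, the reading and writing methods work through precomputed integer index arrays into MuJoCo's flat state and control vectors, so values come back in the fly's own ordering rather than the model's storage order. The NumPy-only sketch below imitates that idea without MuJoCo. The joint names, the address table `QPOSADR_BY_JOINT`, the actuator ids, and the `build_index` helper are hypothetical stand-ins for what the class derives from the compiled model.

```python
import numpy as np

# Hypothetical address table standing in for the compiled model: joint name ->
# position in the flat qpos array. The real class derives these addresses with
# mujoco.mj_name2id and model.jnt_qposadr.
QPOSADR_BY_JOINT = {"coxa": 7, "femur": 8, "tibia": 9, "tarsus": 10}


def build_index(order: list[str], adr_by_name: dict[str, int]) -> np.ndarray:
    """Return flat-array addresses for the names in `order`, preserving that order."""
    return np.array([adr_by_name[name] for name in order], dtype=np.int32)


# Reading: the fly's preferred order differs from the model's storage order.
fly_order = ["tibia", "coxa", "tarsus", "femur"]
qpos_ids = build_index(fly_order, QPOSADR_BY_JOINT)
qpos = np.arange(12, dtype=float)  # placeholder flat state vector
angles = qpos[qpos_ids]  # fancy indexing returns a copy, in fly order

# Writing: validate the length, then scatter inputs into the flat control vector.
ctrl = np.zeros(4)
ctrl_ids = np.array([2, 0, 3, 1], dtype=np.int32)  # hypothetical actuator ids
inputs = np.array([1.0, 2.0, 3.0, 4.0])
if len(inputs) != len(ctrl_ids):
    raise ValueError(f"Expected {len(ctrl_ids)} inputs, got {len(inputs)}")
ctrl[ctrl_ids] = inputs
```

Because NumPy fancy indexing materializes a new array, modifying `angles` in this sketch leaves `qpos` untouched; getters built this way hand back copies, not live views of the simulation state.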

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `world` | `BaseWorld` | A fully configured world with at least one fly attached. | *required* |
| `timestep` | `float \| None` | Keyword-only. Physics timestep in seconds. If `None`, the model's compiled-in timestep (from `mujoco_globals.yaml`) is used. | `None` |
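
The timestep determines how many `step()` calls cover a given stretch of simulated time. The helper below, `steps_for_duration`, is a hypothetical sketch of that conversion, not a FlyGym function. It rounds the quotient because plain `int()` truncation can lose a step to floating-point error.

```python
def steps_for_duration(duration_s: float, timestep: float) -> int:
    """Number of physics steps that spans `duration_s` at the given `timestep`."""
    if timestep <= 0:
        raise ValueError("timestep must be positive")
    # round() instead of int(): the quotient can land just below a whole number.
    return round(duration_s / timestep)


naive = int(0.3 / 0.1)  # 0.3 / 0.1 == 2.9999999999999996, so truncation gives 2
robust = steps_for_duration(0.3, 0.1)  # 3
warmup_steps = steps_for_duration(0.05, 1e-4)  # hypothetical 0.1 ms timestep
```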

Attributes:

| Name | Description |
|---|---|
| `world` | The world used to construct this simulation. |
| `renderer` | The attached `Renderer`, or `None` if not set. |
| `mj_model` | Compiled MuJoCo model. |
| `mj_data` | Associated MuJoCo data. |
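
`mj_data` also holds the active contacts. MuJoCo stores each contact frame as nine values whose rows are the contact normal followed by two tangent directions, and `mujoco.mj_contactForce` reports the force in that frame, so rotating it into world coordinates is a multiply by the transposed frame. The NumPy sketch below uses made-up frame and force values in place of a real model, and applies the sign convention from the contact-force accumulation in the source listing: the world-frame force is added to the second geom and subtracted from the first.

```python
import numpy as np

# Made-up contact frame as MuJoCo stores it: 9 values, rows = (normal, tangent1,
# tangent2). Here the normal points along world +z, from geom1 (ground) to geom2 (leg).
frame_flat = np.array([0.0, 0.0, 1.0,
                       1.0, 0.0, 0.0,
                       0.0, 1.0, 0.0])
# Made-up output of mujoco.mj_contactForce: 3 force components in the contact
# frame (normal, tangent1, tangent2) followed by 3 torque components.
contact_wrench = np.array([2.0, 0.5, -0.25, 0.0, 0.0, 0.0])

frame = frame_flat.reshape(3, 3)
# Rows are the frame axes, so the transpose maps contact-frame coords to world coords.
world_force = frame.T @ contact_wrench[:3]

# Accumulate per geom with the same convention: + on geom2, - on geom1.
net_force = {"ground": np.zeros(3), "leg": np.zeros(3)}
geom1, geom2 = "ground", "leg"
net_force[geom1] -= world_force
net_force[geom2] += world_force
```

With the normal pointing from the ground into the leg, a positive normal force appears as an upward push on the leg, and the two accumulated forces cancel, as Newton's third law requires.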

Source code in src/flygym/simulation.py
class Simulation:
    """CPU-based single-world physics simulation.

    Wraps a compiled MuJoCo model and provides methods for stepping physics,
    reading state, and writing control inputs.

    Args:
        world: A fully configured world with at least one fly attached.
        timestep: Physics timestep in seconds. If None, the model's compiled-in
            timestep (from ``mujoco_globals.yaml``) is used.

    Attributes:
        world: The world used to construct this simulation.
        renderer: The attached `Renderer`, or None if not set.
        mj_model: Compiled MuJoCo model.
        mj_data: Associated MuJoCo data.
    """

    def __init__(self, world: BaseWorld, *, timestep: float | None = None) -> None:
        if len(world.fly_lookup) == 0:
            raise ValueError("The world must contain at least one fly.")
        self.renderer = None
        self.world = world
        self.mj_model, self.mj_data = world.compile()
        if timestep is not None:
            self.mj_model.opt.timestep = timestep
        self._neutral_keyframe_id = mj.mj_name2id(
            self.mj_model, mj.mjtObj.mjOBJ_KEY, "neutral"
        )
        mj.mj_resetDataKeyframe(self.mj_model, self.mj_data, self._neutral_keyframe_id)

        # Map internal IDs in the compiled MuJoCo model. This allows users to read from
        # or write to body/joint/actuator in orders defined by Fly objects.
        self._map_internal_bodyids()
        self._map_internal_geom_ids()
        self._map_internal_ground_geom_ids()
        self._map_internal_qposqveladrs()
        self._map_internal_actuator_ids()
        self._map_internal_adhesionactuator_ids()
        self._map_internal_tendonactuator_ids()
        self._map_internal_jointids()
        self._map_internal_groundcontactsensor_ids()
        self._map_internal_site_ids()
        self._map_internal_eye_camera_ids()

        self.eye_renderer = None
        self.retina = None

        # For performance profiling
        self._curr_step = 0
        self._frames_rendered = 0
        self._total_physics_time_ns = 0
        self._total_render_time_ns = 0

    def reset(self) -> None:
        """Reset simulation and renderer to the neutral keyframe."""
        # Reset physics
        mj.mj_resetDataKeyframe(self.mj_model, self.mj_data, self._neutral_keyframe_id)

        # Reset renderers
        if self.renderer is not None:
            self.renderer.reset()
        # The eye renderer doesn't have to be reset as it's stateless (it's the plain
        # MuJoCo renderer, not our flygym.rendering.Renderer)

        # Reset performance-profiling counters
        self._curr_step = 0
        self._frames_rendered = 0
        self._total_physics_time_ns = 0
        self._total_render_time_ns = 0

    def step(self) -> None:
        """Advance physics by one timestep."""
        mj.mj_step(self.mj_model, self.mj_data)

    def step_with_profile(self) -> None:
        """Advance physics by one timestep, accumulating timing data for profiling."""
        physics_start_ns = perf_counter_ns()
        self.step()
        physics_finish_ns = perf_counter_ns()
        self._total_physics_time_ns += physics_finish_ns - physics_start_ns
        self._curr_step += 1

    def set_renderer(
        self,
        cameras: str | mj.MjsCamera | list[str | mj.MjsCamera],
        *,
        camera_res: tuple[int, int] = (240, 320),
        playback_speed: float = 0.2,
        output_fps: int = 25,
        buffer_frames: bool = True,
        scene_option: mj.MjvOption | None = None,
        **kwargs: Any,
    ) -> Renderer:
        """Attach a renderer to this simulation.

        Args:
            cameras: Camera(s) to render. Can be a camera name, MJCF camera element,
                or a sequence of either.
            camera_res: ``(height, width)`` in pixels.
            playback_speed: Video playback speed relative to real time.
            output_fps: Output video frame rate.
            buffer_frames: If True, store rendered frames in memory.
            scene_option: MuJoCo scene options. Uses defaults if None.
            **kwargs: Passed to ``mujoco.Renderer``.

        Returns:
            The created `Renderer` instance.
        """
        self.renderer = Renderer(
            self.mj_model,
            cameras,
            camera_res=camera_res,
            playback_speed=playback_speed,
            output_fps=output_fps,
            buffer_frames=buffer_frames,
            scene_option=scene_option,
            **kwargs,
        )
        return self.renderer

    def render_as_needed(self) -> bool:
        """Render a frame if enough simulation time has elapsed since the last render.

        Returns:
            True if a frame was rendered, False otherwise.
        """
        return self.renderer.render_as_needed(self.mj_data)

    def render_as_needed_with_profile(self) -> bool:
        """Like `render_as_needed`, but also accumulates render timing data."""
        render_start_ns = perf_counter_ns()
        render_done = self.render_as_needed()
        render_finish_ns = perf_counter_ns()
        self._total_render_time_ns += render_finish_ns - render_start_ns
        if render_done:
            self._frames_rendered += 1
        return render_done

    def get_joint_angles(self, fly_name: str) -> Float[np.ndarray, "n_jointdofs"]:  # noqa: F821
        """Get current joint angles ordered by the fly's skeleton.

        Args:
            fly_name: Name of the fly.

        Returns:
            Joint angles in radians, shape ``(n_jointdofs,)``, ordered as in
            ``fly.get_jointdofs_order()``.
        """
        internal_ids = self._intern_qposadrs_by_fly[fly_name]
        return self.mj_data.qpos[internal_ids]

    def get_joint_velocities(self, fly_name: str) -> Float[np.ndarray, "n_jointdofs"]:  # noqa: F821
        """Get current joint angular velocities ordered by the fly's skeleton.

        Args:
            fly_name: Name of the fly.

        Returns:
            Joint velocities in radians per second, shape ``(n_jointdofs,)``, ordered
            as in ``fly.get_jointdofs_order()``.
        """
        internal_ids = self._intern_qveladrs_by_fly[fly_name]
        return self.mj_data.qvel[internal_ids]

    def get_body_positions(self, fly_name: str) -> Float[np.ndarray, "n_bodies 3"]:
        """Get global 3D positions of all body segments.

        Args:
            fly_name: Name of the fly.

        Returns:
            Body positions in mm, shape ``(n_bodies, 3)``, ordered as in
            ``fly.get_bodysegs_order()``.
        """
        internal_ids = self._internal_bodyids_by_fly[fly_name]
        return self.mj_data.xpos[internal_ids, :]

    def get_body_rotations(self, fly_name: str) -> Float[np.ndarray, "n_bodies 4"]:
        """Get global orientations of all body segments as quaternions (w, x, y, z).

        Args:
            fly_name: Name of the fly.

        Returns:
            Body quaternions, shape ``(n_bodies, 4)``, ordered as in
            ``fly.get_bodysegs_order()``.
        """
        internal_ids = self._internal_bodyids_by_fly[fly_name]
        return self.mj_data.xquat[internal_ids, :]

    def get_actuator_forces(
        self, fly_name: str, actuator_type: ActuatorType
    ) -> Float[np.ndarray, "n_actuators"]:  # noqa: F821
        """Get actuator forces for the given actuator type.

        Args:
            fly_name: Name of the fly.
            actuator_type: Type of actuator to query.

        Returns:
            Actuator forces, shape ``(n_actuators,)``, ordered as in
            ``fly.get_actuated_jointdofs_order(actuator_type)``.
        """
        internal_ids = self._intern_actuatorids_by_type_by_fly[actuator_type][fly_name]
        return self.mj_data.actuator_force[internal_ids]

    def get_ground_contact_info(
        self, fly_name: str
    ) -> tuple[
        Float[np.ndarray, "6"],  # contact/no contact flag
        Float[np.ndarray, "6 3"],  # force (in contact frame)
        Float[np.ndarray, "6 3"],  # torque (in contact frame)
        Float[np.ndarray, "6 3"],  # pos (in global frame)
        Float[np.ndarray, "6 3"],  # normal (in global frame)
        Float[np.ndarray, "6 3"],  # tangent (in global frame)
    ]:
        """Get ground contact information for all six legs.

        Args:
            fly_name: Name of the fly.

        Returns:
            A 6-tuple, one entry per leg ordered as in ``fly.get_legs_order()``:

            - ``contact_found``: shape ``(6,)`` — raw ``found`` channel from the
                MuJoCo contact sensor.
            - ``forces``: shape ``(6, 3)`` — contact force in contact frame.
            - ``torques``: shape ``(6, 3)`` — contact torque in contact frame.
            - ``positions``: shape ``(6, 3)`` — contact position in global frame.
            - ``normals``: shape ``(6, 3)`` — contact normal in global frame.
            - ``tangents``: shape ``(6, 3)`` — contact tangent in global frame.
        """
        internal_ids = self._intern_groundcontactsensorids_by_fly[fly_name]
        sensor_data = self.mj_data.sensordata[internal_ids]
        # Reshape (6 legs * 16 dims per sensor,) to (6 legs, 16 dims per sensor)
        sensor_data = sensor_data.reshape(6, 16)
        contact_found = sensor_data[:, 0]
        forces = sensor_data[:, 1:4]
        torques = sensor_data[:, 4:7]
        positions = sensor_data[:, 7:10]
        normals = sensor_data[:, 10:13]
        tangents = sensor_data[:, 13:]
        return contact_found, forces, torques, positions, normals, tangents

    def get_bodysegment_contact_forces(
        self,
        fly_name: str,
        body_segments: list[BodySegment | str],
        *,
        ground_only: bool = True,
    ) -> Float[np.ndarray, "n_bodysegments 3"]:
        """Get net world-frame contact forces on selected body segments.

        Args:
            fly_name: Name of the fly.
            body_segments: Body segments to query, ordered as desired in the output.
            ground_only: If True, include only contacts with world ground geoms.

        Returns:
            Net force vectors in MuJoCo world coordinates, one row per requested body
            segment.
        """
        requested_segments = [
            seg if isinstance(seg, BodySegment) else BodySegment(seg)
            for seg in body_segments
        ]
        geom_ids_by_segment = self._internal_geomid_by_bodyseg_by_fly[fly_name]
        requested_geom_to_output = {
            geom_ids_by_segment[seg]: i for i, seg in enumerate(requested_segments)
        }
        forces = np.zeros((len(requested_segments), 3), dtype=float)

        ncon = self.mj_data.ncon
        if ncon == 0:
            return forces

        # Vectorised filtering: find relevant contact indices without a Python loop.
        contacts = self.mj_data.contact
        geom1_arr = contacts.geom1[:ncon]
        geom2_arr = contacts.geom2[:ncon]
        exclude_arr = contacts.exclude[:ncon].astype(bool)

        requested_geom_arr = np.array(
            list(requested_geom_to_output.keys()), dtype=np.int32
        )
        geom1_requested = np.isin(geom1_arr, requested_geom_arr)
        geom2_requested = np.isin(geom2_arr, requested_geom_arr)
        active = (geom1_requested | geom2_requested) & ~exclude_arr

        if ground_only:
            ground_arr = self._internal_ground_geom_ids
            geom1_is_ground = np.isin(geom1_arr, ground_arr)
            geom2_is_ground = np.isin(geom2_arr, ground_arr)
            active &= (geom1_requested & geom2_is_ground) | (
                geom2_requested & geom1_is_ground
            )

        contact_wrench = np.zeros(6, dtype=float)
        for contact_id in np.where(active)[0]:
            mj.mj_contactForce(
                self.mj_model, self.mj_data, int(contact_id), contact_wrench
            )
            frame = contacts.frame[contact_id].reshape(3, 3)
            world_force = frame.T @ contact_wrench[:3]

            g1, g2 = int(geom1_arr[contact_id]), int(geom2_arr[contact_id])
            if g1 in requested_geom_to_output:
                forces[requested_geom_to_output[g1]] -= world_force
            if g2 in requested_geom_to_output:
                forces[requested_geom_to_output[g2]] += world_force

        return forces

    def get_site_positions(self, fly_name: str) -> Float[np.ndarray, "n_sites 3"]:
        """Get global 3D positions of anatomical-joint sites.

        Args:
            fly_name: Name of the fly.

        Returns:
            Site positions in mm, shape ``(n_sites, 3)``, ordered as in
            ``fly.get_sites_order()``.
        """
        internal_ids = self._internal_siteids_by_fly[fly_name]
        return self.mj_data.site_xpos[internal_ids, :]

    def set_actuator_inputs(
        self,
        fly_name: str,
        actuator_type: ActuatorType,
        inputs: Float[np.ndarray, "n_actuators"],  # noqa: F821
    ) -> None:
        """Set control inputs for the given actuator type.

        Args:
            fly_name: Name of the fly.
            actuator_type: Type of actuator to control.
            inputs: Control inputs, shape ``(n_actuators,)``, ordered as in
                ``fly.get_actuated_jointdofs_order(actuator_type)``.
        """
        internal_ids = self._intern_actuatorids_by_type_by_fly[actuator_type][fly_name]
        if len(inputs) != len(internal_ids):
            raise ValueError(
                f"Expected {len(internal_ids)} inputs for actuator type "
                f"'{actuator_type.name}', but got {len(inputs)}"
            )
        self.mj_data.ctrl[internal_ids] = inputs

    def set_leg_adhesion_states(
        self, fly_name: str, leg_to_adhesion_state: Float[np.ndarray, "6"]
    ) -> None:
        """Set adhesion states for each leg.

        Args:
            fly_name: Name of the fly.
            leg_to_adhesion_state: Adhesion control per leg, shape ``(6,)``, ordered as
                in ``fly.get_legs_order()``. Values should be in the range ``[0, 1]``.
        """
        internal_ids = self._intern_adhesionactuatorids_by_fly[fly_name]
        if len(leg_to_adhesion_state) != len(internal_ids):
            raise ValueError(
                "Unexpected number of adhesion states: "
                f"expected {len(internal_ids)}, got {len(leg_to_adhesion_state)}"
            )
        self.mj_data.ctrl[internal_ids] = leg_to_adhesion_state

    def set_tendon_actuator_inputs(
        self,
        fly_name: str,
        inputs: Float[np.ndarray, "n_tendon_actuators"],  # noqa: F821
    ) -> None:
        """Set control inputs for tendon actuators.

        Args:
            fly_name: Name of the fly.
            inputs: Control inputs, shape ``(n_tendon_actuators,)``, ordered as in
                ``fly.get_actuated_jointdofs_order(ActuatorType.TENDON)``.
        """
        # Flies without tendon actuators have no entry in the lookup, so default to
        # an empty id array: the length check below then turns an empty input into a
        # clean no-op and a non-empty one into a clear "expected 0" error.
        internal_ids = self._intern_tendonactuatorids_by_fly.get(
            fly_name, np.empty(0, dtype=np.int32)
        )
        if len(inputs) != len(internal_ids):
            raise ValueError(
                f"Expected {len(internal_ids)} tendon actuator inputs, but got "
                f"{len(inputs)}"
            )
        self.mj_data.ctrl[internal_ids] = inputs

    def get_raw_vision(self, fly_name: str) -> Float[np.ndarray, "2 height width 3"]:
        """Render the fly's eye cameras and return fisheye-corrected frames.

        Certain body parts are hidden from the eye cameras to avoid self-occlusion, as
        configured in `flygym/assets/model/neuromechfly/vision.yaml`. These geoms are
        assigned to geom group 2, which _is_ rendered by the MuJoCo renderer by
        default, but the eye renderer within FlyGym is configured to ignore this geom
        group. Geom group 1, which holds the eye-position markers, is hidden as well.

        Args:
            fly_name: Name of the fly to query.

        Returns:
            An array of shape (2, height, width, 3) containing the RGB images from the
            fly's two eyes. The first dimension corresponds to the left and right eye,
            in that order.
        """
        try:
            internal_eye_camera_ids = self._intern_eye_camera_ids_by_fly[fly_name]
        except KeyError:
            raise ValueError(
                f"Fly '{fly_name}' does not have any eye cameras defined. "
                "Make sure to call fly.add_vision() when constructing the fly."
            ) from None

        # Lazy-construct Retina and eye renderer only if user queries visual input
        if self.retina is None:
            from flygym.vision.retina import Retina

            self.retina = Retina()

        if self.eye_renderer is None:
            self.eye_renderer = mj.Renderer(
                self.mj_model,
                height=self.retina.nrows,
                width=self.retina.ncols,
            )
            # Make eye renderer apply option to ignore geoms in group 2, which includes
            # body segments that should not be rendered by the eye cameras to avoid
            # self-occlusion. Disable group 1 as well because markers for eye positions
            # belong to group 1.
            self.eye_renderer_scene_option = mj.MjvOption()
            self.eye_renderer_scene_option.geomgroup[1] = 0
            self.eye_renderer_scene_option.geomgroup[2] = 0

        # Render each eye camera and apply fisheye correction
        frames = []
        for cam_id in internal_eye_camera_ids:
            self.eye_renderer.update_scene(
                self.mj_data, cam_id, scene_option=self.eye_renderer_scene_option
            )
            raw_frame = self.eye_renderer.render()
            fish_img = self.retina.correct_fisheye(raw_frame)
            frames.append(fish_img)
        return np.array(frames)

    def get_ommatidia_readouts(
        self, fly_name: str
    ) -> Float[np.ndarray, "n_cameras n_ommatidia 2"]:
        """Convert the rendered eye frames into ommatidia readouts.

        Args:
            fly_name: Name of the fly to query.

        Returns:
            A float32 array with shape ``(2, n_ommatidia, 2)`` containing
            the yellow/pale channel readings for each eye camera. The first dimension
            corresponds to the left and right eyes, in that order. The last
            dimension corresponds to the yellow- and pale-type ommatidia, in that
            order. Zero values indicate that the ommatidium is of the other type.
            For example, if `readouts[0, 5, 0]` is 0, the ommatidium at index 5 of the
            left eye is of pale type, and the user should look at `readouts[0, 5, 1]`
            instead.
        """
        raw_vision = self.get_raw_vision(fly_name)
        ommatidia_readouts = np.array(
            [self.retina.raw_image_to_hex_pxls(image) for image in raw_vision],
            dtype=np.float32,
        )
        return ommatidia_readouts

    def warmup(self, duration_s: float = 0.05) -> None:
        """Step the simulation for a short period to settle initialization transients.

        Call after `reset` and before the main simulation loop to allow the fly to
        settle onto the ground.

        Args:
            duration_s: Duration of the warmup period in seconds.
        """
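        # Round rather than truncate so that floating-point error in the quotient
        # (e.g. 0.3 / 0.1 == 2.9999999999999996) does not drop a step.
        n_steps = round(duration_s / self.mj_model.opt.timestep)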
        for _ in range(n_steps):
            self.step()

    def _map_internal_bodyids(self) -> None:
        internal_bodyids_by_fly = defaultdict(list)

        for fly_name, fly in self.world.fly_lookup.items():
            for bodyseg, mjcf_body_element in fly.bodyseg_to_mjcfbody.items():
                internal_body_id = mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_BODY,
                    mjcf_body_element.name,
                )
                internal_bodyids_by_fly[fly_name].append(internal_body_id)

        self._internal_bodyids_by_fly = {
            k: np.array(v, dtype=np.int32) for k, v in internal_bodyids_by_fly.items()
        }

    def _map_internal_geom_ids(self) -> None:
        internal_geomids_by_bodyseg_by_fly = {}

        for fly_name, fly in self.world.fly_lookup.items():
            internal_geomids_by_bodyseg_by_fly[fly_name] = {}
            for bodyseg, mjcf_geom_elements in fly.bodyseg_to_mjcfgeom.items():
                for mjcf_geom_element in mjcf_geom_elements:
                    internal_geom_id = mj.mj_name2id(
                        self.mj_model,
                        mj.mjtObj.mjOBJ_GEOM,
                        mjcf_geom_element.name,
                    )
                    internal_geomids_by_bodyseg_by_fly[fly_name][bodyseg] = (
                        internal_geom_id
                    )

        self._internal_geomid_by_bodyseg_by_fly = internal_geomids_by_bodyseg_by_fly

    def _map_internal_ground_geom_ids(self) -> None:
        internal_ground_geom_ids = []
        for ground_geom in getattr(self.world, "ground_geoms", []):
            internal_ground_geom_ids.append(
                mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_GEOM,
                    ground_geom.name,
                )
            )
        self._internal_ground_geom_ids = np.array(
            internal_ground_geom_ids, dtype=np.int32
        )

    def _map_internal_jointids(self) -> None:
        internal_jointids_by_fly = defaultdict(list)

        for fly_name, fly in self.world.fly_lookup.items():
            for jointdof, mjcf_joint_element in fly.jointdof_to_mjcfjoint.items():
                internal_joint_id = mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_JOINT,
                    mjcf_joint_element.name,
                )
                internal_jointids_by_fly[fly_name].append(internal_joint_id)

        self._internal_jointids_by_fly = {
            k: np.array(v, dtype=np.int32) for k, v in internal_jointids_by_fly.items()
        }

    def _map_internal_qposqveladrs(self) -> None:
        internal_qposadrs_by_fly = defaultdict(list)
        internal_qveladrs_by_fly = defaultdict(list)

        for fly_name, fly in self.world.fly_lookup.items():
            for jointdof, mjcf_joint_element in fly.jointdof_to_mjcfjoint.items():
                internal_joint_id = mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_JOINT,
                    mjcf_joint_element.name,
                )
                qposadr = self.mj_model.jnt_qposadr[internal_joint_id]
                qveladr = self.mj_model.jnt_dofadr[internal_joint_id]
                internal_qposadrs_by_fly[fly_name].append(qposadr)
                internal_qveladrs_by_fly[fly_name].append(qveladr)

        self._intern_qposadrs_by_fly = {
            k: np.array(v, dtype=np.int32) for k, v in internal_qposadrs_by_fly.items()
        }
        self._intern_qveladrs_by_fly = {
            k: np.array(v, dtype=np.int32) for k, v in internal_qveladrs_by_fly.items()
        }

    def _map_internal_actuator_ids(self) -> None:
        internal_actuatorids_by_fly_by_type = defaultdict(lambda: defaultdict(list))

        for fly_name, fly in self.world.fly_lookup.items():
            for actuator_ty, actuators in fly.jointdof_to_mjcfactuator_by_type.items():
                for jointdof, actuator_element in actuators.items():
                    internal_actuator_id = mj.mj_name2id(
                        self.mj_model,
                        mj.mjtObj.mjOBJ_ACTUATOR,
                        actuator_element.name,
                    )
                    internal_actuatorids_by_fly_by_type[actuator_ty][fly_name].append(
                        internal_actuator_id
                    )

        self._intern_actuatorids_by_type_by_fly = {
            actuator_ty: {
                fly_name: np.array(ids, dtype=np.int32)
                for fly_name, ids in ids_by_fly.items()
            }
            for actuator_ty, ids_by_fly in internal_actuatorids_by_fly_by_type.items()
        }

    def _map_internal_tendonactuator_ids(self) -> None:
        internal_tendonactuatorids_by_fly = defaultdict(list)
        for fly_name, fly in self.world.fly_lookup.items():
            if len(fly.jointdof_to_mjcfactuator_by_type[ActuatorType.TENDON]) == 0:
                continue  # This fly doesn't have any tendon actuators
            for jointdof, actuator_element in fly.jointdof_to_mjcfactuator_by_type[
                ActuatorType.TENDON
            ].items():
                internal_actuator_id = mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_ACTUATOR,
                    actuator_element.name,
                )
                internal_tendonactuatorids_by_fly[fly_name].append(internal_actuator_id)
        self._intern_tendonactuatorids_by_fly = {
            fly_name: np.array(ids, dtype=np.int32)
            for fly_name, ids in internal_tendonactuatorids_by_fly.items()
        }

    def _map_internal_adhesionactuator_ids(self) -> None:
        internal_adhesionactuatorids_by_fly = defaultdict(list)
        for fly_name, fly in self.world.fly_lookup.items():
            if len(fly.leg_to_adhesionactuator) == 0:
                continue  # This fly doesn't have leg adhesion actuators
            for leg in fly.get_legs_order():
                actuator_element = fly.leg_to_adhesionactuator[leg]
                internal_actuator_id = mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_ACTUATOR,
                    actuator_element.name,
                )
                internal_adhesionactuatorids_by_fly[fly_name].append(
                    internal_actuator_id
                )
        self._intern_adhesionactuatorids_by_fly = {
            fly_name: np.array(ids, dtype=np.int32)
            for fly_name, ids in internal_adhesionactuatorids_by_fly.items()
        }

    def _map_internal_groundcontactsensor_ids(self) -> None:
        if self.world.legpos_to_groundcontactsensors_by_fly is None:
            self._intern_groundcontactsensorids_by_fly = None
            return

        self._intern_groundcontactsensorids_by_fly = {}

        for fly_name, fly in self.world.fly_lookup.items():
            indices_thisfly = []
            sensors_by_leg = self.world.legpos_to_groundcontactsensors_by_fly.get(
                fly_name, {}
            )
            for leg in fly.get_legs_order():
                sensor = sensors_by_leg.get(leg)
                if sensor is None:
                    continue
                internal_id = mj.mj_name2id(
                    self.mj_model, mj.mjtObj.mjOBJ_SENSOR, sensor.name
                )
                start_idx = self.mj_model.sensor_adr[internal_id]
                sensor_dim = self.mj_model.sensor_dim[internal_id]
                # Sensor should be 16-dim: found (1), force (3), torque (3), pos (3),
                # normal (3), tangent (3)
                assert sensor_dim == 16, "unexpected ground contact sensor dimension"
                indices_thisfly.extend(list(range(start_idx, start_idx + sensor_dim)))
            indices_arr = np.array(indices_thisfly, dtype=np.int32)
            self._intern_groundcontactsensorids_by_fly[fly_name] = indices_arr

    def _map_internal_site_ids(self) -> None:
        internal_siteids_by_fly = {
            fly_name: [] for fly_name in self.world.fly_lookup.keys()
        }

        for fly_name, fly in self.world.fly_lookup.items():
            for _, mjcf_site_element in fly.anatomicaljoint_to_mjcfsites.items():
                internal_site_id = mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_SITE,
                    mjcf_site_element.name,
                )
                internal_siteids_by_fly[fly_name].append(internal_site_id)

        self._internal_siteids_by_fly = {
            k: np.array(v, dtype=np.int32) for k, v in internal_siteids_by_fly.items()
        }

    def _map_internal_eye_camera_ids(self):
        internal_eye_camera_ids_by_fly = defaultdict(list)

        for fly_name, fly in self.world.fly_lookup.items():
            for eye_camera_element in fly.eyecameraname_to_mjcfcamera.values():
                internal_eye_camera_id = mj.mj_name2id(
                    self.mj_model,
                    mj.mjtObj.mjOBJ_CAMERA,
                    eye_camera_element.name,
                )
                internal_eye_camera_ids_by_fly[fly_name].append(internal_eye_camera_id)

        self._intern_eye_camera_ids_by_fly = {
            k: np.array(v, dtype=np.int32)
            for k, v in internal_eye_camera_ids_by_fly.items()
        }

    @property
    def time(self) -> float:
        """Current simulation time in seconds."""
        return self.mj_data.time

    def print_performance_report(
        self, show_in_notebook: bool | Literal["auto"] = "auto"
    ) -> None:
        """Print a summary of physics and rendering performance.

        Requires that `step_with_profile` and `render_as_needed_with_profile` were
        used during the simulation loop.

        Args:
            show_in_notebook: If True, render the report as an HTML table suitable for
                display in a Jupyter notebook. If "auto", will attempt to detect if
                we're in a notebook environment and choose accordingly.
        """
        print_perf_report(
            n_steps=self._curr_step,
            n_frames_rendered=self._frames_rendered,
            total_physics_time_ns=self._total_physics_time_ns,
            total_render_time_ns=self._total_render_time_ns,
            timestep=self.timestep,
            show_in_notebook=show_in_notebook,
        )

    @property
    def timestep(self) -> float:
        """Simulation timestep in seconds."""
        return self.mj_model.opt.timestep

    def close(self):
        """Clean up resources allocated by the simulation.

        This method is idempotent (safe to call multiple times).
        """

        # Use getattr to handle cases where attributes may not exist
        renderer = getattr(self, "renderer", None)
        if renderer is not None:
            renderer.close()
        eye_renderer = getattr(self, "eye_renderer", None)
        if eye_renderer is not None:
            eye_renderer.close()

        # Clear references to help GC and make close idempotent
        self.renderer = None
        self.eye_renderer = None
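The `_map_internal_*` helpers above share one pattern. Each one resolves a named MJCF element to its internal id with `mj_name2id`. For sensors, it then expands that id into the contiguous slice the sensor occupies in `mj_data.sensordata`, using `sensor_adr` and `sensor_dim`. Below is a minimal pure-Python sketch of the expansion step. The address and dimension lists are made up and stand in for the real `mj_model` arrays:

```python
def flat_sensor_indices(sensor_ids, sensor_adr, sensor_dim):
    """Expand sensor ids into the indices they occupy in a flat sensordata buffer.

    sensor_adr[i] is the start address of sensor i and sensor_dim[i] its width,
    mirroring mj_model.sensor_adr and mj_model.sensor_dim.
    """
    indices = []
    for sid in sensor_ids:
        start = sensor_adr[sid]
        indices.extend(range(start, start + sensor_dim[sid]))
    return indices


# Hypothetical model: sensor 0 is 3-dim, sensors 1 and 2 are 16-dim contact sensors.
sensor_adr = [0, 3, 19]
sensor_dim = [3, 16, 16]
idx = flat_sensor_indices([1, 2], sensor_adr, sensor_dim)
print(len(idx), idx[0], idx[-1])  # 32 3 34
```

The flat index array is computed once at construction time. After that, reading all legs' sensor values takes a single indexing operation per call.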

time property

Current simulation time in seconds.

timestep property

Simulation timestep in seconds.

close()

Clean up resources allocated by the simulation.

This method is idempotent (safe to call multiple times).

Source code in src/flygym/simulation.py
def close(self):
    """Clean up resources allocated by the simulation.

    This method is idempotent (safe to call multiple times).
    """

    # Use getattr to handle cases where attributes may not exist
    renderer = getattr(self, "renderer", None)
    if renderer is not None:
        renderer.close()
    eye_renderer = getattr(self, "eye_renderer", None)
    if eye_renderer is not None:
        eye_renderer.close()

    # Clear references to help GC and make close idempotent
    self.renderer = None
    self.eye_renderer = None
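Because `close()` is idempotent, you can combine it with the standard-library `contextlib.closing` helper and still call it again by hand. The sketch below uses hypothetical `FakeSimulation` and `FakeRenderer` classes that mirror the logic of `close()` shown above, so it runs without FlyGym or MuJoCo:

```python
from contextlib import closing


class FakeRenderer:
    """Hypothetical renderer that counts how often it is closed."""

    def __init__(self):
        self.n_closed = 0

    def close(self):
        self.n_closed += 1


class FakeSimulation:
    """Hypothetical stand-in mirroring the close() logic of Simulation."""

    def __init__(self):
        self.renderer = FakeRenderer()

    def close(self):
        renderer = getattr(self, "renderer", None)
        if renderer is not None:
            renderer.close()
        self.renderer = None  # later calls find None and do nothing


sim = FakeSimulation()
renderer = sim.renderer
with closing(sim):
    pass  # the simulation loop would go here

sim.close()  # calling again is harmless
print(renderer.n_closed)  # 1
```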

get_actuator_forces(fly_name, actuator_type)

Get actuator forces for the given actuator type.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required
actuator_type ActuatorType

Type of actuator to query.

required

Returns:

Type Description
Float[ndarray, n_actuators]

Actuator forces, shape (n_actuators,), ordered as in fly.get_actuated_jointdofs_order(actuator_type).

Source code in src/flygym/simulation.py
def get_actuator_forces(
    self, fly_name: str, actuator_type: ActuatorType
) -> Float[np.ndarray, "n_actuators"]:  # noqa: F821
    """Get actuator forces for the given actuator type.

    Args:
        fly_name: Name of the fly.
        actuator_type: Type of actuator to query.

    Returns:
        Actuator forces, shape ``(n_actuators,)``, ordered as in
        ``fly.get_actuated_jointdofs_order(actuator_type)``.
    """
    internal_ids = self._intern_actuatorids_by_type_by_fly[actuator_type][fly_name]
    return self.mj_data.actuator_force[internal_ids]

get_body_positions(fly_name)

Get global 3D positions of all body segments.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required

Returns:

Type Description
Float[ndarray, 'n_bodies 3']

Body positions in mm, shape (n_bodies, 3), ordered as in fly.get_bodysegs_order().

Source code in src/flygym/simulation.py
def get_body_positions(self, fly_name: str) -> Float[np.ndarray, "n_bodies 3"]:
    """Get global 3D positions of all body segments.

    Args:
        fly_name: Name of the fly.

    Returns:
        Body positions in mm, shape ``(n_bodies, 3)``, ordered as in
        ``fly.get_bodysegs_order()``.
    """
    internal_ids = self._internal_bodyids_by_fly[fly_name]
    return self.mj_data.xpos[internal_ids, :]
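The getters in this class read MuJoCo buffers with an integer id array, for example `self.mj_data.xpos[internal_ids, :]`. In NumPy, indexing with an integer array returns a copy, not a view. Each call therefore yields a snapshot that does not change when the simulation steps afterwards. Writing into the returned array also leaves the simulation state untouched. Here is a small NumPy demonstration, with a zero array standing in for `mj_data.xpos` and hypothetical ids:

```python
import numpy as np

# Stand-in for mj_data.xpos: 4 bodies in the model, 3 coordinates each.
xpos = np.zeros((4, 3))
internal_ids = np.array([2, 3], dtype=np.int32)  # hypothetical ids of this fly's bodies

snapshot = xpos[internal_ids, :]  # integer-array indexing -> a new array (copy)
xpos[2, 0] = 1.5  # physics "advances" and writes into the buffer

print(snapshot[0, 0], xpos[internal_ids, :][0, 0])  # 0.0 1.5
```

To record a trajectory, call the getter after each step and append the result. The stored arrays will not be overwritten later.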

get_body_rotations(fly_name)

Get global orientations of all body segments as quaternions (w, x, y, z).

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required

Returns:

Type Description
Float[ndarray, 'n_bodies 4']

Body quaternions, shape (n_bodies, 4), ordered as in fly.get_bodysegs_order().

Source code in src/flygym/simulation.py
def get_body_rotations(self, fly_name: str) -> Float[np.ndarray, "n_bodies 4"]:
    """Get global orientations of all body segments as quaternions (w, x, y, z).

    Args:
        fly_name: Name of the fly.

    Returns:
        Body quaternions, shape ``(n_bodies, 4)``, ordered as in
        ``fly.get_bodysegs_order()``.
    """
    internal_ids = self._internal_bodyids_by_fly[fly_name]
    return self.mj_data.xquat[internal_ids, :]
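Quaternions from `get_body_rotations` use MuJoCo's (w, x, y, z) order. As an example of post-processing them, the sketch below extracts the rotation about the world z axis using the standard Z-Y-X (yaw-pitch-roll) decomposition. Whether this yaw matches a fly's walking heading depends on how the body frames are oriented in the model; that correspondence is an assumption here:

```python
import math


def yaw_from_wxyz(q):
    """Rotation about the world z axis (radians) for a unit quaternion (w, x, y, z).

    Uses the standard Z-Y-X (yaw-pitch-roll) Euler decomposition.
    """
    w, x, y, z = q
    return math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))


# 90 degree rotation about z: w = cos(45 deg), z = sin(45 deg)
half = math.radians(90) / 2
q = (math.cos(half), 0.0, 0.0, math.sin(half))
print(round(math.degrees(yaw_from_wxyz(q)), 6))  # 90.0
```

SciPy's `scipy.spatial.transform.Rotation.from_quat` expects scalar-last (x, y, z, w) order by default. Reorder these quaternions before passing them to it.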

get_bodysegment_contact_forces(fly_name, body_segments, *, ground_only=True)

Get net world-frame contact forces on selected body segments.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required
body_segments list[BodySegment | str]

Body segments to query, ordered as desired in the output.

required
ground_only bool

If True, include only contacts with world ground geoms.

True

Returns:

Type Description
Float[ndarray, 'n_bodysegments 3']

Net force vectors in MuJoCo world coordinates, one row per requested body segment.

Source code in src/flygym/simulation.py
def get_bodysegment_contact_forces(
    self,
    fly_name: str,
    body_segments: list[BodySegment | str],
    *,
    ground_only: bool = True,
) -> Float[np.ndarray, "n_bodysegments 3"]:
    """Get net world-frame contact forces on selected body segments.

    Args:
        fly_name: Name of the fly.
        body_segments: Body segments to query, ordered as desired in the output.
        ground_only: If True, include only contacts with world ground geoms.

    Returns:
        Net force vectors in MuJoCo world coordinates, one row per requested body
        segment.
    """
    requested_segments = [
        seg if isinstance(seg, BodySegment) else BodySegment(seg)
        for seg in body_segments
    ]
    geom_ids_by_segment = self._internal_geomid_by_bodyseg_by_fly[fly_name]
    requested_geom_to_output = {
        geom_ids_by_segment[seg]: i for i, seg in enumerate(requested_segments)
    }
    forces = np.zeros((len(requested_segments), 3), dtype=float)

    ncon = self.mj_data.ncon
    if ncon == 0:
        return forces

    # Vectorised filtering: find relevant contact indices without a Python loop.
    contacts = self.mj_data.contact
    geom1_arr = contacts.geom1[:ncon]
    geom2_arr = contacts.geom2[:ncon]
    exclude_arr = contacts.exclude[:ncon].astype(bool)

    requested_geom_arr = np.array(
        list(requested_geom_to_output.keys()), dtype=np.int32
    )
    geom1_requested = np.isin(geom1_arr, requested_geom_arr)
    geom2_requested = np.isin(geom2_arr, requested_geom_arr)
    active = (geom1_requested | geom2_requested) & ~exclude_arr

    if ground_only:
        ground_arr = self._internal_ground_geom_ids
        geom1_is_ground = np.isin(geom1_arr, ground_arr)
        geom2_is_ground = np.isin(geom2_arr, ground_arr)
        active &= (geom1_requested & geom2_is_ground) | (
            geom2_requested & geom1_is_ground
        )

    contact_wrench = np.zeros(6, dtype=float)
    for contact_id in np.where(active)[0]:
        mj.mj_contactForce(
            self.mj_model, self.mj_data, int(contact_id), contact_wrench
        )
        frame = contacts.frame[contact_id].reshape(3, 3)
        world_force = frame.T @ contact_wrench[:3]

        g1, g2 = int(geom1_arr[contact_id]), int(geom2_arr[contact_id])
        if g1 in requested_geom_to_output:
            forces[requested_geom_to_output[g1]] -= world_force
        if g2 in requested_geom_to_output:
            forces[requested_geom_to_output[g2]] += world_force

    return forces
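In `get_bodysegment_contact_forces`, `mj_contactForce` returns the wrench in the contact frame. MuJoCo stores that frame as a row-major 3x3 matrix: the first row is the contact normal, and the other two rows span the tangent plane. Multiplying by the transpose therefore maps contact-frame components to world coordinates. The method then adds the result to the force on `geom2` and subtracts it from the force on `geom1`. Below is a NumPy sketch with a hypothetical contact whose normal is world +z:

```python
import numpy as np

# Contact frame stored row-major as 9 numbers: row 0 is the contact normal,
# rows 1 and 2 span the tangent plane. Here the normal is world +z.
frame = np.array([0.0, 0.0, 1.0,
                  1.0, 0.0, 0.0,
                  0.0, 1.0, 0.0]).reshape(3, 3)

# Force part of a hypothetical contact wrench: 2.0 along the normal,
# 0.5 along the first tangent, nothing along the second.
contact_force = np.array([2.0, 0.5, 0.0])

world_force = frame.T @ contact_force  # contact frame -> world frame
force_on_geom2 = world_force   # accumulated with "+=" in the method
force_on_geom1 = -world_force  # equal and opposite, accumulated with "-="
print(world_force.tolist())  # [0.5, 0.0, 2.0]
```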

get_ground_contact_info(fly_name)

Get ground contact information for all six legs.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required

Returns:

Type Description
tuple[Float[ndarray, 6], Float[ndarray, '6 3'], Float[ndarray, '6 3'], Float[ndarray, '6 3'], Float[ndarray, '6 3'], Float[ndarray, '6 3']]

A 6-tuple of arrays. Within each array, the first axis indexes the legs, ordered as in fly.get_legs_order():

  • contact_found: shape (6,), raw found channel from the MuJoCo contact sensor.
  • forces: shape (6, 3), contact force in the contact frame.
  • torques: shape (6, 3), contact torque in the contact frame.
  • positions: shape (6, 3), contact position in the global frame.
  • normals: shape (6, 3), contact normal in the global frame.
  • tangents: shape (6, 3), contact tangent in the global frame.
Source code in src/flygym/simulation.py
def get_ground_contact_info(
    self, fly_name: str
) -> tuple[
    Float[np.ndarray, "6"],  # contact/no contact flag
    Float[np.ndarray, "6 3"],  # force (in contact frame)
    Float[np.ndarray, "6 3"],  # torque (in contact frame)
    Float[np.ndarray, "6 3"],  # pos (in global frame)
    Float[np.ndarray, "6 3"],  # normal (in global frame)
    Float[np.ndarray, "6 3"],  # tangent (in global frame)
]:
    """Get ground contact information for all six legs.

    Args:
        fly_name: Name of the fly.

    Returns:
        A 6-tuple of arrays. Within each array, the first axis indexes the legs,
        ordered as in ``fly.get_legs_order()``:

        - ``contact_found``: shape ``(6,)``, raw ``found`` channel from the
            MuJoCo contact sensor.
        - ``forces``: shape ``(6, 3)``, contact force in the contact frame.
        - ``torques``: shape ``(6, 3)``, contact torque in the contact frame.
        - ``positions``: shape ``(6, 3)``, contact position in the global frame.
        - ``normals``: shape ``(6, 3)``, contact normal in the global frame.
        - ``tangents``: shape ``(6, 3)``, contact tangent in the global frame.
    """
    internal_ids = self._intern_groundcontactsensorids_by_fly[fly_name]
    sensor_data = self.mj_data.sensordata[internal_ids]
    # Reshape (6 legs * 16 dims per sensor,) to (6 legs, 16 dim per sensor)
    sensor_data = sensor_data.reshape(6, 16)
    contact_found = sensor_data[:, 0]
    forces = sensor_data[:, 1:4]
    torques = sensor_data[:, 4:7]
    positions = sensor_data[:, 7:10]
    normals = sensor_data[:, 10:13]
    tangents = sensor_data[:, 13:]
    return contact_found, forces, torques, positions, normals, tangents
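`get_ground_contact_info` relies on each contact sensor occupying 16 consecutive values. The layout is found (1), force (3), torque (3), position (3), normal (3), tangent (3). The sketch below splits a hypothetical flat buffer the same way and derives a per-leg contact mask. Treating any positive `found` value as contact is an assumption of this example, not a documented FlyGym rule:

```python
import numpy as np

N_FIELDS = 16  # found (1), force (3), torque (3), pos (3), normal (3), tangent (3)

# Stand-in for mj_data.sensordata[internal_ids] for six legs.
flat = np.zeros(6 * N_FIELDS)
flat[0::N_FIELDS] = [1, 0, 1, 1, 0, 1]  # "found" channel of each leg
flat[1] = 0.02  # hypothetical normal force component for the first leg

per_leg = flat.reshape(-1, N_FIELDS)  # one row per leg
contact_found = per_leg[:, 0]
forces = per_leg[:, 1:4]

in_contact = contact_found > 0  # assumed threshold on the raw "found" channel
print(in_contact.tolist())  # [True, False, True, True, False, True]
```

Calling `reshape(-1, N_FIELDS)` lets NumPy infer the number of legs. The method itself hard-codes 6.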

get_joint_angles(fly_name)

Get current joint angles ordered by the fly's skeleton.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required

Returns:

Type Description
Float[ndarray, n_jointdofs]

Joint angles in radians, shape (n_jointdofs,), ordered as in fly.get_jointdofs_order().

Source code in src/flygym/simulation.py
def get_joint_angles(self, fly_name: str) -> Float[np.ndarray, "n_jointdofs"]:  # noqa: F821
    """Get current joint angles ordered by the fly's skeleton.

    Args:
        fly_name: Name of the fly.

    Returns:
        Joint angles in radians, shape ``(n_jointdofs,)``, ordered as in
        ``fly.get_jointdofs_order()``.
    """
    internal_ids = self._intern_qposadrs_by_fly[fly_name]
    return self.mj_data.qpos[internal_ids]

get_joint_velocities(fly_name)

Get current joint angular velocities ordered by the fly's skeleton.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required

Returns:

Type Description
Float[ndarray, n_jointdofs]

Joint velocities in radians per second, shape (n_jointdofs,), ordered as in fly.get_jointdofs_order().

Source code in src/flygym/simulation.py
def get_joint_velocities(self, fly_name: str) -> Float[np.ndarray, "n_jointdofs"]:  # noqa: F821
    """Get current joint angular velocities ordered by the fly's skeleton.

    Args:
        fly_name: Name of the fly.

    Returns:
        Joint velocities in radians per second, shape ``(n_jointdofs,)``, ordered
        as in ``fly.get_jointdofs_order()``.
    """
    internal_ids = self._intern_qveladrs_by_fly[fly_name]
    return self.mj_data.qvel[internal_ids]

get_ommatidia_readouts(fly_name)

Convert the rendered eye frames into ommatidia readouts.

Parameters:

Name Type Description Default
fly_name str

Name of the fly to query.

required

Returns:

Type Description
Float[ndarray, 'n_cameras n_ommatidia 2']

A float32 array with shape (2, n_ommatidia, 2) containing the yellow/pale channel readings for each eye camera. The first dimension corresponds to the left and right eyes, in that order. The last dimension corresponds to the yellow- and pale-type ommatidia, in that order. A zero in one channel indicates that the ommatidium is of the other type. For example, if readouts[0, 5, 0] is 0, the ommatidium at index 5 of the left eye is of pale type, and its reading is in readouts[0, 5, 1].

Source code in src/flygym/simulation.py
def get_ommatidia_readouts(
    self, fly_name: str
) -> Float[np.ndarray, "n_cameras n_ommatidia 2"]:
    """Convert the rendered eye frames into ommatidia readouts.

    Args:
        fly_name: Name of the fly to query.

    Returns:
        A float32 array with shape ``(2, n_ommatidia, 2)`` containing
        the yellow/pale channel readings for each eye camera. The first dimension
        corresponds to the left and right eyes, in that order. The last
        dimension corresponds to the yellow- and pale-type ommatidia, in that
        order. A zero in one channel indicates that the ommatidium is of the other
        type. For example, if `readouts[0, 5, 0]` is 0, the ommatidium at index 5
        of the left eye is of pale type, and its reading is in `readouts[0, 5, 1]`.
    """
    raw_vision = self.get_raw_vision(fly_name)
    ommatidia_readouts = np.array(
        [self.retina.raw_image_to_hex_pxls(image) for image in raw_vision],
        dtype=np.float32,
    )
    return ommatidia_readouts
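Each ommatidium reports in only one of the two channels returned by `get_ommatidia_readouts`. Summing over the last axis therefore collapses the output into a single intensity per ommatidium. Here is a sketch with a small hypothetical array in place of real readouts:

```python
import numpy as np

# Hypothetical readouts for 2 eyes x 4 ommatidia x (yellow, pale) channels.
# Each ommatidium has at most one nonzero channel.
readouts = np.array(
    [
        [[0.8, 0.0], [0.0, 0.3], [0.5, 0.0], [0.0, 0.9]],
        [[0.0, 0.1], [0.2, 0.0], [0.0, 0.0], [0.7, 0.0]],
    ],
    dtype=np.float32,
)

# Summing over the channel axis gives one intensity per ommatidium,
# relying on the unused channel being exactly zero.
intensity = readouts.sum(axis=-1)
print(intensity.shape)  # (2, 4)
```

This discards the yellow/pale distinction. An ommatidium that receives no light reads zero in both channels, as in the third entry of the second eye above. In that case, the zero pattern alone does not reveal its type.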

get_raw_vision(fly_name)

Render the fly's eye cameras and return fisheye-corrected frames.

Certain body parts are invisible to the eye cameras to avoid self-occlusion, as configured in flygym/assets/model/neuromechfly/vision.yaml. These geoms are assigned to geom group 2, which is rendered by the MuJoCo renderer by default, but the eye renderer within FlyGym is configured to ignore this geom group. The eye renderer also hides geom group 1, which contains the markers for the eye positions.

Parameters:

Name Type Description Default
fly_name str

Name of the fly to query.

required

Returns:

Type Description
Float[ndarray, '2 height width 3']

An array of shape (2, height, width, 3) containing the RGB images from the fly's two eyes. The first dimension corresponds to the left and right eye, in that order.

Source code in src/flygym/simulation.py
def get_raw_vision(self, fly_name: str) -> Float[np.ndarray, "2 height width 3"]:
    """Render the fly's eye cameras and return fisheye-corrected frames.

    Certain body parts are invisible to the eye cameras to avoid self-occlusion, as
    configured in `flygym/assets/model/neuromechfly/vision.yaml`. These geoms are
    assigned to geom group 2, which _is_ rendered by the MuJoCo renderer by default,
    but the eye renderer within FlyGym is configured to ignore this geom group. It
    also hides geom group 1, which contains the markers for the eye positions.

    Args:
        fly_name: Name of the fly to query.

    Returns:
        An array of shape (2, height, width, 3) containing the RGB images from the
        fly's two eyes. The first dimension corresponds to the left and right eye,
        in that order.
    """
    try:
        internal_eye_camera_ids = self._intern_eye_camera_ids_by_fly[fly_name]
    except KeyError:
        raise ValueError(
            f"Fly '{fly_name}' does not have any eye cameras defined. "
            "Make sure to call fly.add_vision() when constructing the fly."
        )

    # Lazy-construct Retina and eye renderer only if user queries visual input
    if self.retina is None:
        from flygym.vision.retina import Retina

        self.retina = Retina()

    if self.eye_renderer is None:
        self.eye_renderer = mj.Renderer(
            self.mj_model,
            height=self.retina.nrows,
            width=self.retina.ncols,
        )
        # Make eye renderer apply option to ignore geoms in group 2, which includes
        # body segments that should not be rendered by the eye cameras to avoid
        # self-occlusion. Disable group 1 as well because markers for eye positions
        # belong to group 1.
        self.eye_renderer_scene_option = mj.MjvOption()
        self.eye_renderer_scene_option.geomgroup[1] = 0
        self.eye_renderer_scene_option.geomgroup[2] = 0

    # Render each eye camera and apply fisheye correction
    frames = []
    for cam_id in internal_eye_camera_ids:
        self.eye_renderer.update_scene(
            self.mj_data, cam_id, scene_option=self.eye_renderer_scene_option
        )
        raw_frame = self.eye_renderer.render()
        fish_img = self.retina.correct_fisheye(raw_frame)
        frames.append(fish_img)
    return np.array(frames)

get_site_positions(fly_name)

Get global 3D positions of anatomical-joint sites.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required

Returns:

Type Description
Float[ndarray, 'n_sites 3']

Site positions in mm, shape (n_sites, 3), ordered as in fly.get_sites_order().

Source code in src/flygym/simulation.py
def get_site_positions(self, fly_name: str) -> Float[np.ndarray, "n_sites 3"]:
    """Get global 3D positions of anatomical-joint sites.

    Args:
        fly_name: Name of the fly.

    Returns:
        Site positions in mm, shape ``(n_sites, 3)``, ordered as in
        ``fly.get_sites_order()``.
    """
    internal_ids = self._internal_siteids_by_fly[fly_name]
    return self.mj_data.site_xpos[internal_ids, :]

print_performance_report(show_in_notebook='auto')

Print a summary of physics and rendering performance.

Requires that step_with_profile and render_as_needed_with_profile were used during the simulation loop.

Parameters:

Name Type Description Default
show_in_notebook bool | Literal['auto']

If True, render the report as an HTML table suitable for display in a Jupyter notebook. If "auto", will attempt to detect if we're in a notebook environment and choose accordingly.

'auto'
Source code in src/flygym/simulation.py
def print_performance_report(
    self, show_in_notebook: bool | Literal["auto"] = "auto"
) -> None:
    """Print a summary of physics and rendering performance.

    Requires that `step_with_profile` and `render_as_needed_with_profile` were
    used during the simulation loop.

    Args:
        show_in_notebook: If True, render the report as an HTML table suitable for
            display in a Jupyter notebook. If "auto", will attempt to detect if
            we're in a notebook environment and choose accordingly.
    """
    print_perf_report(
        n_steps=self._curr_step,
        n_frames_rendered=self._frames_rendered,
        total_physics_time_ns=self._total_physics_time_ns,
        total_render_time_ns=self._total_render_time_ns,
        timestep=self.timestep,
        show_in_notebook=show_in_notebook,
    )

render_as_needed()

Render a frame if enough simulation time has elapsed since the last render.

Returns:

Type Description
bool

True if a frame was rendered, False otherwise.

Source code in src/flygym/simulation.py
def render_as_needed(self) -> bool:
    """Render a frame if enough simulation time has elapsed since the last render.

    Returns:
        True if a frame was rendered, False otherwise.
    """
    return self.renderer.render_as_needed(self.mj_data)

render_as_needed_with_profile()

Like render_as_needed, but also accumulates render timing data.

Source code in src/flygym/simulation.py
def render_as_needed_with_profile(self) -> bool:
    """Like `render_as_needed`, but also accumulates render timing data."""
    render_start_ns = perf_counter_ns()
    render_done = self.render_as_needed()
    render_finish_ns = perf_counter_ns()
    self._total_render_time_ns += render_finish_ns - render_start_ns
    if render_done:
        self._frames_rendered += 1
    return render_done
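`step_with_profile` accumulates the number of steps in `_curr_step`. It also accumulates the wall-clock nanoseconds spent in `step()` in `_total_physics_time_ns`. As one example of what these counters allow, the sketch below computes a real-time factor, meaning simulated seconds per wall-clock second of physics. This section does not show whether `print_performance_report` reports this exact figure. The counter values are hypothetical:

```python
def real_time_factor(n_steps: int, timestep: float, total_physics_time_ns: int) -> float:
    """Simulated seconds advanced per wall-clock second spent in physics."""
    simulated_s = n_steps * timestep
    wall_s = total_physics_time_ns / 1e9
    return simulated_s / wall_s


# Hypothetical counters: 10,000 steps of 0.1 ms took 2.5 s of wall time.
print(real_time_factor(10_000, 1e-4, 2_500_000_000))  # ~0.4, slower than real time
```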

reset()

Reset the physics state to the neutral keyframe, reset the attached renderer (if any), and clear the performance-profiling counters.

Source code in src/flygym/simulation.py
def reset(self) -> None:
    """Reset physics to the neutral keyframe and reset the attached renderer.

    Also clears the counters used by `print_performance_report`.
    """
    # Reset physics
    mj.mj_resetDataKeyframe(self.mj_model, self.mj_data, self._neutral_keyframe_id)

    # Reset renderers
    if self.renderer is not None:
        self.renderer.reset()
    # The eye renderer doesn't have to be reset as it's stateless (it's the plain
    # MuJoCo renderer, not our flygym.rendering.Renderer)

    # Stuff for performance profiling
    self._curr_step = 0
    self._frames_rendered = 0
    self._total_physics_time_ns = 0
    self._total_render_time_ns = 0

set_actuator_inputs(fly_name, actuator_type, inputs)

Set control inputs for the given actuator type.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required
actuator_type ActuatorType

Type of actuator to control.

required
inputs Float[ndarray, n_actuators]

Control inputs, shape (n_actuators,), ordered as in fly.get_actuated_jointdofs_order(actuator_type).

required
Source code in src/flygym/simulation.py
def set_actuator_inputs(
    self,
    fly_name: str,
    actuator_type: ActuatorType,
    inputs: Float[np.ndarray, "n_actuators"],  # noqa: F821
) -> None:
    """Set control inputs for the given actuator type.

    Args:
        fly_name: Name of the fly.
        actuator_type: Type of actuator to control.
        inputs: Control inputs, shape ``(n_actuators,)``, ordered as in
            ``fly.get_actuated_jointdofs_order(actuator_type)``.
    """
    internal_ids = self._intern_actuatorids_by_type_by_fly[actuator_type][fly_name]
    if len(inputs) != len(internal_ids):
        raise ValueError(
            f"Expected {len(internal_ids)} inputs for actuator type "
            f"'{actuator_type.name}', but got {len(inputs)}"
        )
    self.mj_data.ctrl[internal_ids] = inputs
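`set_actuator_inputs` checks only the length of `inputs`, so ordering mistakes pass silently. One way to guard against them is to keep targets in a mapping. You can then build the array from the order returned by `fly.get_actuated_jointdofs_order(actuator_type)`. In this sketch, the helper `ordered_inputs` and the DoF names are hypothetical. With FlyGym, the keys would be whatever objects that method returns:

```python
import numpy as np


def ordered_inputs(targets, dof_order):
    """Arrange per-DoF targets into the order expected by set_actuator_inputs.

    Raises KeyError if a DoF in ``dof_order`` has no target.
    """
    return np.array([targets[dof] for dof in dof_order], dtype=float)


# Hypothetical DoF identifiers; in practice the order would come from
# fly.get_actuated_jointdofs_order(actuator_type).
dof_order = ["lf_coxa_pitch", "lf_femur_pitch", "lf_tibia_pitch"]
targets = {"lf_tibia_pitch": -0.6, "lf_coxa_pitch": 0.2, "lf_femur_pitch": -1.1}
inputs = ordered_inputs(targets, dof_order)
print(inputs.tolist())  # [0.2, -1.1, -0.6]
```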

set_leg_adhesion_states(fly_name, leg_to_adhesion_state)

Set adhesion states for each leg.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required
leg_to_adhesion_state Float[ndarray, 6]

Adhesion control per leg, shape (6,), ordered as in fly.get_legs_order(). Values should be in the range [0, 1].

required
Source code in src/flygym/simulation.py
def set_leg_adhesion_states(
    self, fly_name: str, leg_to_adhesion_state: Float[np.ndarray, "6"]
) -> None:
    """Set adhesion states for each leg.

    Args:
        fly_name: Name of the fly.
        leg_to_adhesion_state: Adhesion control per leg, shape ``(6,)``, ordered as
            in ``fly.get_legs_order()``. Values should be in the range ``[0, 1]``.
    """
    internal_ids = self._intern_adhesionactuatorids_by_fly[fly_name]
    if len(leg_to_adhesion_state) != len(internal_ids):
        raise ValueError(
            "Unexpected number of adhesion states: "
            f"expected {len(internal_ids)}, got {len(leg_to_adhesion_state)}"
        )
    self.mj_data.ctrl[internal_ids] = leg_to_adhesion_state
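`set_leg_adhesion_states` expects one value per leg in `[0, 1]`, ordered as in `fly.get_legs_order()`. The sketch below shows two hypothetical ways to build that array. The first converts boolean per-leg flags to 0/1 values. The second clips a graded controller signal into the documented range:

```python
import numpy as np

# Hypothetical per-leg stance flags, ordered like fly.get_legs_order().
in_stance = np.array([True, False, True, False, True, False])

# Full adhesion for stance legs, none for swing legs.
adhesion = in_stance.astype(float)

# For graded signals from a controller, clip into the documented range first.
raw = np.array([1.3, -0.2, 0.5, 0.0, 0.9, 1.0])
adhesion_graded = np.clip(raw, 0.0, 1.0)
print(adhesion_graded.tolist())  # [1.0, 0.0, 0.5, 0.0, 0.9, 1.0]
```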

set_renderer(cameras, *, camera_res=(240, 320), playback_speed=0.2, output_fps=25, buffer_frames=True, scene_option=None, **kwargs)

Attach a renderer to this simulation.

Parameters:

Name Type Description Default
cameras str | MjsCamera | list[str | MjsCamera]

Camera(s) to render. Can be a camera name, MJCF camera element, or a sequence of either.

required
camera_res tuple[int, int]

(height, width) in pixels.

(240, 320)
playback_speed float

Video playback speed relative to real time.

0.2
output_fps int

Output video frame rate.

25
buffer_frames bool

If True, store rendered frames in memory.

True
scene_option MjvOption | None

MuJoCo scene options. Uses defaults if None.

None
**kwargs Any

Passed to mujoco.Renderer.

{}

Returns:

Type Description
Renderer

The created Renderer instance.

Source code in src/flygym/simulation.py
def set_renderer(
    self,
    cameras: str | mj.MjsCamera | list[str | mj.MjsCamera],
    *,
    camera_res: tuple[int, int] = (240, 320),
    playback_speed: float = 0.2,
    output_fps: int = 25,
    buffer_frames: bool = True,
    scene_option: mj.MjvOption | None = None,
    **kwargs: Any,
) -> Renderer:
    """Attach a renderer to this simulation.

    Args:
        cameras: Camera(s) to render. Can be a camera name, MJCF camera element,
            or a sequence of either.
        camera_res: ``(height, width)`` in pixels.
        playback_speed: Video playback speed relative to real time.
        output_fps: Output video frame rate.
        buffer_frames: If True, store rendered frames in memory.
        scene_option: MuJoCo scene options. Uses defaults if None.
        **kwargs: Passed to ``mujoco.Renderer``.

    Returns:
        The created `Renderer` instance.
    """
    self.renderer = Renderer(
        self.mj_model,
        cameras,
        camera_res=camera_res,
        playback_speed=playback_speed,
        output_fps=output_fps,
        buffer_frames=buffer_frames,
        scene_option=scene_option,
        **kwargs,
    )
    return self.renderer
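With `playback_speed=0.2` and `output_fps=25`, one second of output video covers 0.2 s of simulated time. A frame is therefore due every 0.2 / 25 = 0.008 s of simulated time. The sketch below converts that interval into physics steps. It assumes that `render_as_needed` follows exactly this rule and uses a hypothetical timestep of 0.1 ms. The actual timestep comes from the model or from the `timestep` argument:

```python
def steps_per_rendered_frame(timestep: float, playback_speed: float, output_fps: int) -> int:
    """Physics steps between rendered frames, assuming one frame per
    playback_speed / output_fps seconds of simulated time."""
    frame_interval_s = playback_speed / output_fps  # simulated seconds per video frame
    return round(frame_interval_s / timestep)


# Defaults of set_renderer with a hypothetical 0.1 ms physics timestep.
print(steps_per_rendered_frame(1e-4, 0.2, 25))  # 80
```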

set_tendon_actuator_inputs(fly_name, inputs)

Set control inputs for tendon actuators.

Parameters:

Name Type Description Default
fly_name str

Name of the fly.

required
inputs Float[ndarray, n_tendon_actuators]

Control inputs, shape (n_tendon_actuators,), ordered as in fly.get_actuated_jointdofs_order(ActuatorType.TENDON).

required
Source code in src/flygym/simulation.py
def set_tendon_actuator_inputs(
    self,
    fly_name: str,
    inputs: Float[np.ndarray, "n_tendon_actuators"],  # noqa: F821
) -> None:
    """Set control inputs for tendon actuators.

    Args:
        fly_name: Name of the fly.
        inputs: Control inputs, shape ``(n_tendon_actuators,)``, ordered as in
            ``fly.get_actuated_jointdofs_order(ActuatorType.TENDON)``.
    """
    # Flies without tendon actuators have no entry in the lookup, so default to
    # an empty id array: the length check below then turns an empty input into a
    # clean no-op and a non-empty one into a clear "expected 0" error.
    internal_ids = self._intern_tendonactuatorids_by_fly.get(
        fly_name, np.empty(0, dtype=np.int32)
    )
    if len(inputs) != len(internal_ids):
        raise ValueError(
            f"Expected {len(internal_ids)} tendon actuator inputs, but got "
            f"{len(inputs)}"
        )
    self.mj_data.ctrl[internal_ids] = inputs

step()

Advance physics by one timestep.

Source code in src/flygym/simulation.py
def step(self) -> None:
    """Advance physics by one timestep."""
    mj.mj_step(self.mj_model, self.mj_data)

step_with_profile()

Advance physics by one timestep, accumulating timing data for profiling.

Source code in src/flygym/simulation.py
def step_with_profile(self) -> None:
    """Advance physics by one timestep, accumulating timing data for profiling."""
    physics_start_ns = perf_counter_ns()
    self.step()
    physics_finish_ns = perf_counter_ns()
    self._total_physics_time_ns += physics_finish_ns - physics_start_ns
    self._curr_step += 1

warmup(duration_s=0.05)

Step the simulation for a short period to settle initialization transients.

Call after reset and before the main simulation loop to allow the fly to settle onto the ground.

Parameters:

Name Type Description Default
duration_s float

Duration of the warmup period in seconds.

0.05
Source code in src/flygym/simulation.py
def warmup(self, duration_s: float = 0.05) -> None:
    """Step the simulation for a short period to settle initialization transients.

    Call after `reset` and before the main simulation loop to allow the fly to
    settle onto the ground.

    Args:
        duration_s: Duration of the warmup period in seconds.
    """
    n_steps = int(duration_s / self.mj_model.opt.timestep)
    for _ in range(n_steps):
        self.step()
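`warmup` converts `duration_s` to a step count by dividing by the timestep and truncating with `int()`. Most decimal values have no exact binary representation, so such a quotient can land just below an integer, and truncation then drops a step. Below is a sketch of a rounding-based alternative; the helper name `warmup_steps` is hypothetical:

```python
def warmup_steps(duration_s: float, timestep: float) -> int:
    """Number of physics steps covering duration_s, robust to float error."""
    return round(duration_s / timestep)


print(int(0.3 / 0.1))            # 2, because 0.3 / 0.1 evaluates to 2.9999999999999996
print(warmup_steps(0.3, 0.1))    # 3
print(warmup_steps(0.05, 1e-4))  # 500
```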