Skip to content

world

BaseWorld

Bases: BaseCompositionElement, ABC

Base class for worlds that contain environmental features that the fly can interact with (e.g., ground) and define how flies are attached to the world (e.g., free-floating or tethered). A world can contain multiple flies that can interact with one another.

Concrete subclasses typically override __init__ to set up environmental features (e.g., ground plane) and _attach_fly_mjcf to define how flies are attached. See method documentation below for details.

Attributes:

Name Type Description
name

Name of the world.

fly_lookup dict[str, Fly]

A dictionary mapping fly names to Fly objects in the world.

mjcf_root RootElement

The root element of the world's MJCF model (fly MJCF models are attached to this root).

world_dof_neutral_states

A dictionary mapping names of DoFs managed by the world (e.g., free joints by which flies are attached to the world) to their neutral state values. The neutral state is 1D for slide and hinge joints, 4D for ball joints (quaternion), and 7D for free joints (position + orientation).

Source code in src/flygym/compose/world.py
class BaseWorld(BaseCompositionElement, ABC):
    """Base class for worlds that contain environmental features that the fly can
    interact with (e.g., ground) and define how flies are attached to the world (e.g.,
    free-floating or tethered). A world can contain multiple flies that can interact
    with one another.

    Concrete subclasses typically override `__init__` to set up environmental features
    (e.g., ground plane) and `_attach_fly_mjcf` to define how flies are attached. See
    method documentation below for details.

    Attributes:
        name:
            Name of the world.
        fly_lookup:
            A dictionary mapping fly names to `Fly` objects in the world.
        mjcf_root:
            The root element of the world's MJCF model (fly MJCF models are attached to
            this root).
        world_dof_neutral_states:
            A dictionary mapping names of DoFs managed by the world (e.g., free joints
            by which flies are attached to the world) to their neutral state values.
            The neutral state is 1D for slide and hinge joints, 4D for ball joints
            (quaternion), and 7D for free joints (position + orientation).
    """

    def __init__(self, name: str) -> None:
        """Initialize the world and its underlying MJCF model.

        Concrete subclasses should call this first (i.e., `super().__init__(name)`) as
        it sets up a few essential attributes.
        """
        self._mjcf_root = mjcf.RootElement(model=name)
        self._fly_lookup: dict[str, Fly] = {}
        # Neutral states of DoFs owned by the world itself (e.g. the free joints that
        # attach flies), keyed by the joint's full MJCF identifier.
        self.world_dof_neutral_states: dict[str, list[float]] = {}
        # The single world-level "neutral" keyframe; rebuilt whenever a fly is added.
        self._neutral_keyframe = self.mjcf_root.keyframe.add(
            "key", name="neutral", time=0
        )

    @override
    @property
    def mjcf_root(self) -> mjcf.RootElement:
        """Root element of the world's MJCF model (fly models are attached to it)."""
        return self._mjcf_root

    @property
    def fly_lookup(self) -> dict[str, Fly]:
        """Lookup for `Fly` objects in the world, keyed by fly name."""
        return self._fly_lookup

    @abstractmethod
    def _attach_fly_mjcf(
        self,
        fly: Fly,
        spawn_position: Vec3,
        spawn_rotation: Rotation3D,
        *args,
        **kwargs,
    ) -> mjcf.Element:
        """Attach the fly's MJCF root to the world MJCF model.

        Concrete subclasses should implement this method instead of overriding
        `add_fly()` directly. The `add_fly()` method handles registering the fly under
        `fly_lookup` and updating neutral states; this method is responsible only for
        connecting the fly's MJCF model to the world's MJCF model.

        Use `dm_control.mjcf`'s `attach()` method to attach the fly's MJCF model. See
        `FlatGroundWorld` for an example. More details can be found in the
        [`dm_control.mjcf` documentation](https://github.com/google-deepmind/dm_control/tree/main/dm_control/mjcf#attaching-models).

        Args:
            fly: The fly being added.
            spawn_position: Spawn ``(x, y, z)`` position in mm.
            spawn_rotation: Spawn orientation.
            *args: Extra positional arguments given to `add_fly()`.
            **kwargs: Extra keyword arguments given to `add_fly()`.

        Returns:
            The free joint element created by the attachment. `add_fly()` records the
            spawn pose as this joint's neutral state.
        """
        pass

    def add_fly(
        self,
        fly: Fly,
        spawn_position: Vec3,
        spawn_rotation: Rotation3D,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Attach a fly to the world at the specified pose.

        The fly's MJCF model is merged into the world and registered under
        `fly_lookup`. Extra positional and keyword arguments are forwarded to the
        subclass `_attach_fly_mjcf` implementation (see the specific world subclass for
        available options).

        All input validation happens before the world is modified, so a call that
        raises `ValueError` leaves the world unchanged.

        Args:
            fly: The fly to add.
            spawn_position: Initial ``(x, y, z)`` position in mm.
            spawn_rotation: Initial orientation as a `Rotation3D` in quaternion format.
            *args: Forwarded to `_attach_fly_mjcf`.
            **kwargs: Forwarded to `_attach_fly_mjcf`.

        Raises:
            ValueError: If a fly with the same name already exists in the world.
            ValueError: If ``spawn_rotation`` is not in quaternion format.
        """
        # Validate first so that rejected calls do not leave a half-registered,
        # already-attached fly behind.
        if fly.name in self._fly_lookup:
            raise ValueError(f"Fly with name '{fly.name}' already exists in the world.")
        # The freejoint neutral state is stored as [x, y, z, qw, qx, qy, qz].
        if spawn_rotation.format != "quat":
            raise ValueError(
                "Freejoint neutral rotation can only be specified in quaternion format "
                f"for now. Got {spawn_rotation}."
            )

        # Register fly in the fly lookup
        self._fly_lookup[fly.name] = fly

        # Remove neutral keyframes that are already generated by the fly. Neutral states
        # are globally managed at the world level. A single neutral keyframe will be
        # managed by the world from now on.
        neutral_keyframe = fly.mjcf_root.keyframe.find("key", "neutral")
        if neutral_keyframe is not None:
            neutral_keyframe.remove()

        # Attach the fly's MJCF root to the world MJCF model with a free joint.
        # This is an abstract method that must be implemented by concrete world classes.
        freejoint = self._attach_fly_mjcf(
            fly, spawn_position, spawn_rotation, *args, **kwargs
        )

        # Set neutral state for the freejoint attaching the fly to the world
        neutral_state = [*spawn_position, *spawn_rotation.values]
        self.world_dof_neutral_states[freejoint.full_identifier] = neutral_state

        self._rebuild_neutral_keyframe()

    def _rebuild_neutral_keyframe(self) -> None:
        """Regenerate the world-level ``"neutral"`` keyframe from scratch.

        Compiles the current world model and fills ``qpos`` from two sources:
        (1) the neutral states in `world_dof_neutral_states`, and (2) each fly's
        ``_get_neutral_qpos``. ``ctrl`` is filled from each fly's ``_get_neutral_ctrl``.
        Entries that no source sets stay zero.

        Raises:
            RuntimeError: If a joint listed in `world_dof_neutral_states` cannot be
                found in the MJCF model or in the compiled model.
            FlyGymInternalError: If a fly sets a nonzero neutral value for a qpos or
                ctrl entry that already holds a nonzero value.
        """
        mj_model, _ = self.compile()
        neutral_qpos = np.zeros(mj_model.nq)
        neutral_ctrl = np.zeros(mj_model.nu)

        # Step 1: set neutral qpos for DoFs created by the world
        # dm_control.mjcf has trouble finding freejoints by name with
        # .find("joint", freejoint_name), but they do show up in the list of all joints
        # obtained with .find_all("joint"). So we build a mapping manually in order to
        # set the neutral pose for freejoints corresponding to fly spawns.
        all_world_joints = {
            j.full_identifier: j for j in self.mjcf_root.find_all("joint")
        }
        for joint_name, neutral_state in self.world_dof_neutral_states.items():
            joint_element = all_world_joints.get(joint_name)
            if joint_element is None:
                raise RuntimeError(
                    f"Joint '{joint_name}' not found when rebuilding neutral keyframe."
                )
            joint_type = (
                "free" if joint_element.tag == "freejoint" else joint_element.type
            )
            internal_jointid = mj.mj_name2id(
                mj_model, mj.mjtObj.mjOBJ_JOINT, joint_element.full_identifier
            )
            if internal_jointid < 0:
                # mj_name2id returns -1 when the name is missing; indexing with -1
                # would silently target the last joint.
                raise RuntimeError(
                    f"Joint '{joint_name}' not found in the compiled model when "
                    "rebuilding neutral keyframe."
                )
            # Use the joint's address in qpos (jnt_qposadr), NOT its address in qvel
            # (jnt_dofadr). They diverge once any free joint (7 qpos / 6 dofs) or ball
            # joint (4 qpos / 3 dofs) precedes this joint, e.g. from the second fly on.
            qposadr_start = mj_model.jnt_qposadr[internal_jointid]
            qposadr_end = qposadr_start + _STATE_DIM_BY_JOINT_TYPE[joint_type]
            neutral_qpos[qposadr_start:qposadr_end] = neutral_state

        # Step 2: handle joints and actuators belonging to flies attached to the world
        for fly_name, fly in self.fly_lookup.items():
            # Copy neutral joint angles from fly
            qpos_filled_by_fly = fly._get_neutral_qpos(mj_model)
            indices_to_fill = qpos_filled_by_fly.nonzero()
            has_conflict = np.any(~np.isclose(neutral_qpos[indices_to_fill], 0))
            if has_conflict:
                raise FlyGymInternalError(
                    f"Conflict in neutral joint angles: fly '{fly_name}' is trying "
                    "to set neutral qpos values for DoFs that already have their "
                    "neutral qpos set."
                )
            neutral_qpos[indices_to_fill] = qpos_filled_by_fly[indices_to_fill]

            # Copy neutral actuator inputs from fly
            ctrl_filled_by_fly = fly._get_neutral_ctrl(mj_model)
            indices_to_fill = ctrl_filled_by_fly.nonzero()
            has_conflict = np.any(~np.isclose(neutral_ctrl[indices_to_fill], 0))
            if has_conflict:
                raise FlyGymInternalError(
                    f"Conflict in neutral actuator inputs: fly '{fly_name}' is trying "
                    "to set neutral ctrl values for actuators that already have their "
                    "neutral ctrl set."
                )
            neutral_ctrl[indices_to_fill] = ctrl_filled_by_fly[indices_to_fill]

        self._neutral_keyframe.qpos = neutral_qpos
        self._neutral_keyframe.ctrl = neutral_ctrl

fly_lookup property

Lookup for Fly objects in the world, keyed by fly name.

__init__(name)

Initialize the world and its underlying MJCF model.

Concrete subclasses should call this first (i.e., super().__init__(name)) as it sets up a few essential attributes.

Source code in src/flygym/compose/world.py
def __init__(self, name: str) -> None:
    """Initialize the world and its underlying MJCF model.

    Concrete subclasses should call this first (i.e., `super().__init__(name)`) as
    it sets up a few essential attributes: the MJCF root, the fly lookup table, the
    table of world-owned DoF neutral states, and the world-level "neutral" keyframe.
    """
    # Plain bookkeeping containers; these do not depend on the MJCF model.
    self._fly_lookup: dict[str, Fly] = {}
    self.world_dof_neutral_states = {}

    # The MJCF root has to exist before a keyframe can be added to it.
    self._mjcf_root = mjcf.RootElement(model=name)
    keyframe_section = self.mjcf_root.keyframe
    self._neutral_keyframe = keyframe_section.add("key", name="neutral", time=0)

add_fly(fly, spawn_position, spawn_rotation, *args, **kwargs)

Attach a fly to the world at the specified pose.

The fly's MJCF model is merged into the world and registered under fly_lookup. Extra positional and keyword arguments are forwarded to the subclass _attach_fly_mjcf implementation (see the specific world subclass for available options).

Parameters:

Name Type Description Default
fly Fly

The fly to add.

required
spawn_position Vec3

Initial (x, y, z) position in mm.

required
spawn_rotation Rotation3D

Initial orientation as a Rotation3D in quaternion format.

required
*args Any

Forwarded to _attach_fly_mjcf.

()
**kwargs Any

Forwarded to _attach_fly_mjcf.

{}

Raises:

Type Description
ValueError

If a fly with the same name already exists in the world.

ValueError

If spawn_rotation is not in quaternion format.

Source code in src/flygym/compose/world.py
def add_fly(
    self,
    fly: Fly,
    spawn_position: Vec3,
    spawn_rotation: Rotation3D,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Attach a fly to the world at the specified pose.

    The fly's MJCF model is merged into the world and registered under
    `fly_lookup`. Extra positional and keyword arguments are forwarded to the
    subclass `_attach_fly_mjcf` implementation (see the specific world subclass for
    available options).

    All input validation happens before the world is modified, so a call that
    raises `ValueError` leaves the world unchanged.

    Args:
        fly: The fly to add.
        spawn_position: Initial ``(x, y, z)`` position in mm.
        spawn_rotation: Initial orientation as a `Rotation3D` in quaternion format.
        *args: Forwarded to `_attach_fly_mjcf`.
        **kwargs: Forwarded to `_attach_fly_mjcf`.

    Raises:
        ValueError: If a fly with the same name already exists in the world.
        ValueError: If ``spawn_rotation`` is not in quaternion format.
    """
    # Validate first so that rejected calls do not leave a half-registered,
    # already-attached fly behind.
    if fly.name in self._fly_lookup:
        raise ValueError(f"Fly with name '{fly.name}' already exists in the world.")
    # The freejoint neutral state is stored as [x, y, z, qw, qx, qy, qz].
    if spawn_rotation.format != "quat":
        raise ValueError(
            "Freejoint neutral rotation can only be specified in quaternion format "
            f"for now. Got {spawn_rotation}."
        )

    # Register fly in the fly lookup
    self._fly_lookup[fly.name] = fly

    # Remove neutral keyframes that are already generated by the fly. Neutral states
    # are globally managed at the world level. A single neutral keyframe will be
    # managed by the world from now on.
    neutral_keyframe = fly.mjcf_root.keyframe.find("key", "neutral")
    if neutral_keyframe is not None:
        neutral_keyframe.remove()

    # Attach the fly's MJCF root to the world MJCF model with a free joint.
    # This is an abstract method that must be implemented by concrete world classes.
    freejoint = self._attach_fly_mjcf(
        fly, spawn_position, spawn_rotation, *args, **kwargs
    )

    # Set neutral state for the freejoint attaching the fly to the world
    neutral_state = [*spawn_position, *spawn_rotation.values]
    self.world_dof_neutral_states[freejoint.full_identifier] = neutral_state

    self._rebuild_neutral_keyframe()

FlatGroundWorld

Bases: BaseWorld

World with a flat ground plane (MuJoCo planes are infinite for collision purposes; half_size sets the rendered extent). Flies are free to move.

When calling add_fly, the following extra keyword arguments are accepted:

  • bodysegs_with_ground_contact: Body segments that collide with the ground. Accepts a ContactBodiesPreset, a preset string, or a collection of BodySegment objects. Default: ContactBodiesPreset.LEGS_THORAX_ABDOMEN_HEAD.
  • ground_contact_params: ContactParams for friction and contact physics. Default: ContactParams().
  • add_ground_contact_sensors: If True, add one contact force sensor for each leg that has at least one segment in bodysegs_with_ground_contact. Default: True.

Parameters:

Name Type Description Default
name str

Name of the world.

'flat_ground_world'
half_size float

Half-size of the ground plane in mm.

1000
Source code in src/flygym/compose/world.py
class FlatGroundWorld(BaseWorld):
    """World with a flat ground plane. Flies are free to move.

    MuJoCo planes are infinite for collision purposes; ``half_size`` sets the rendered
    extent. The plane geom has ``contype=0`` and ``conaffinity=0``, so flies only
    touch it through the explicit contact pairs created when a fly is added.

    When calling `add_fly`, the following extra keyword arguments are accepted:

    - ``bodysegs_with_ground_contact``: Body segments that collide with the ground.
      Accepts a `ContactBodiesPreset`, a preset string, or a collection of
      `BodySegment` objects. Default: ``ContactBodiesPreset.LEGS_THORAX_ABDOMEN_HEAD``.
    - ``ground_contact_params``: `ContactParams` for friction and contact physics.
      Default: ``ContactParams()`` (a fresh instance for every call).
    - ``add_ground_contact_sensors``: If True, add one contact force sensor for each
      leg that has at least one segment in ``bodysegs_with_ground_contact``.
      Default: ``True``.

    Args:
        name: Name of the world.
        half_size: Half-size of the ground plane in mm.

    Attributes:
        ground_geom: The ground plane geom element.
        legpos_to_groundcontactsensors_by_fly: ``None`` until ground contact sensors
            are first added. Afterwards it maps fly name -> {leg position -> contact
            sensor element}, with one entry for every fly that had sensors added.
    """

    @override
    def __init__(
        self, name: str = "flat_ground_world", *, half_size: float = 1000
    ) -> None:
        super().__init__(name=name)

        checker_texture = self.mjcf_root.asset.add(
            "texture",
            name="checker",
            type="2d",
            builtin="checker",
            width=300,
            height=300,
            rgb1=(0.3, 0.3, 0.3),
            rgb2=(0.4, 0.4, 0.4),
        )
        grid_material = self.mjcf_root.asset.add(
            "material",
            name="grid",
            texture=checker_texture,
            texrepeat=(250, 250),
            reflectance=0.2,
        )
        # contype/conaffinity = 0: no automatic collisions; ground contacts are only
        # created via explicit <pair> elements in `_set_ground_contact`.
        self.ground_geom = self.mjcf_root.worldbody.add(
            "geom",
            type="plane",
            name="ground_plane",
            material=grid_material,
            pos=(0, 0, 0),
            size=(half_size, half_size, 1),
            contype=0,
            conaffinity=0,
        )
        self.legpos_to_groundcontactsensors_by_fly = None

    @override
    def _attach_fly_mjcf(
        self,
        fly: Fly,
        spawn_position: Vec3,
        spawn_rotation: Rotation3D,
        *,
        bodysegs_with_ground_contact: (
            list[BodySegment] | ContactBodiesPreset | str
        ) = ContactBodiesPreset.LEGS_THORAX_ABDOMEN_HEAD,
        ground_contact_params: ContactParams | None = None,
        add_ground_contact_sensors: bool = True,
    ) -> mjcf.Element:
        """Attach ``fly`` at the spawn pose with a free joint and set up ground contact.

        Args:
            fly: The fly to attach.
            spawn_position: Spawn ``(x, y, z)`` position in mm.
            spawn_rotation: Spawn orientation.
            bodysegs_with_ground_contact: Segments that collide with the ground, as a
                preset, a preset string, or a collection of `BodySegment` objects.
            ground_contact_params: Contact physics parameters. ``None`` means a
                fresh ``ContactParams()``.
            add_ground_contact_sensors: Whether to add per-leg contact sensors.

        Returns:
            The free joint connecting the fly to the world.
        """
        # Build the default per call rather than sharing one instance created at
        # function-definition time.
        if ground_contact_params is None:
            ground_contact_params = ContactParams()

        spawn_site = self.mjcf_root.worldbody.add(
            "site", name=fly.name, pos=spawn_position, **spawn_rotation.as_kwargs()
        )
        freejoint = spawn_site.attach(fly.mjcf_root).add("freejoint", name=fly.name)

        if isinstance(bodysegs_with_ground_contact, ContactBodiesPreset | str):
            preset = ContactBodiesPreset(bodysegs_with_ground_contact)
            bodysegs_with_ground_contact = preset.to_body_segments_list()
        else:
            # Materialize once: both helpers below iterate over the segments, and a
            # one-shot iterator would appear empty to the second one.
            bodysegs_with_ground_contact = list(bodysegs_with_ground_contact)

        self._set_ground_contact(
            fly, bodysegs_with_ground_contact, ground_contact_params
        )
        if add_ground_contact_sensors:
            self._add_ground_contact_sensors(fly, bodysegs_with_ground_contact)
        return freejoint

    def _set_ground_contact(
        self,
        fly: Fly,
        bodysegs_with_ground_contact: list[BodySegment],
        ground_contact_params: ContactParams,
    ) -> None:
        """Add one explicit contact pair between each listed segment geom and ground.

        Raises:
            ValueError: If the fly model has no geom named after a listed segment.
        """
        # NOTE(review): pair names are not prefixed with the fly name, so adding a
        # second fly with overlapping contact segments reuses the same names in the
        # world model, which dm_control rejects. Confirm whether multi-fly support
        # requires fly-specific names.
        for body_segment in bodysegs_with_ground_contact:
            body_geom = fly.mjcf_root.find("geom", body_segment.name)
            if body_geom is None:
                raise ValueError(
                    f"Fly '{fly.name}' has no geom named '{body_segment.name}' to "
                    "set up ground contact for."
                )
            self.mjcf_root.contact.add(
                "pair",
                geom1=body_geom,
                geom2=self.ground_geom,
                name=f"{body_segment.name}-ground",
                friction=ground_contact_params.get_friction_tuple(),
                solref=ground_contact_params.get_solref_tuple(),
                solimp=ground_contact_params.get_solimp_tuple(),
                margin=ground_contact_params.margin,
            )

    def _add_ground_contact_sensors(
        self, fly: Fly, bodysegs_with_ground_contact: list[BodySegment]
    ) -> None:
        """Add one net-force ground contact sensor per leg with contact segments.

        Each sensor covers the subtree rooted at the leg's most proximal contact
        segment. Sensors are recorded under
        ``legpos_to_groundcontactsensors_by_fly[fly.name][leg]``. Entries for flies
        added earlier are kept.
        """
        # Create the mapping once. The previous code re-created it on every call,
        # which discarded the sensors of all previously added flies.
        if self.legpos_to_groundcontactsensors_by_fly is None:
            self.legpos_to_groundcontactsensors_by_fly = defaultdict(dict)

        contact_segs_by_leg = defaultdict(list)
        for bodyseg in bodysegs_with_ground_contact:
            if bodyseg.is_leg():
                contact_segs_by_leg[bodyseg.pos].append(bodyseg)

        # NOTE(review): sensor names are not fly-specific either (see the note in
        # `_set_ground_contact`).
        for leg, contact_segs in contact_segs_by_leg.items():
            subtree_rootseg = _sort_legsegs_prox2dist(contact_segs)[0]
            subtree_rootseg_body = fly.bodyseg_to_mjcfbody[subtree_rootseg]
            sensor = self.mjcf_root.sensor.add(
                "contact",
                subtree1=subtree_rootseg_body,
                geom2=self.ground_geom,
                num=1,
                reduce="netforce",
                data="found force torque pos normal tangent",
                name=f"ground_contact_{leg}_leg",
            )
            self.legpos_to_groundcontactsensors_by_fly.setdefault(fly.name, {})[
                leg
            ] = sensor

TetheredWorld

Bases: BaseWorld

World where the fly body is fixed in space via a weld constraint.

The fly's appendages (legs, wings, etc.) can still move. Useful for motor control experiments without locomotion.

Parameters:

Name Type Description Default
name str

Name of the world.

'tethered_world'
Source code in src/flygym/compose/world.py
class TetheredWorld(BaseWorld):
    """World in which each fly's root body is held in place by a weld constraint.

    Only the root segment is welded; appendages (legs, wings, etc.) can still move.
    This suits motor-control experiments that do not involve locomotion. No ground
    plane is created in this world.

    Args:
        name: Name of the world.
    """

    @override
    def __init__(self, name: str = "tethered_world") -> None:
        super().__init__(name=name)
        # No ground geometry here, so there are no ground contact sensors either.
        self.legpos_to_groundcontactsensors_by_fly = None

    @override
    def _attach_fly_mjcf(
        self, fly, spawn_position: Vec3, spawn_rotation: Rotation3D
    ) -> mjcf.Element:
        """Attach ``fly`` with a free joint and weld its root body to the world.

        Returns:
            The free joint created for the fly. The weld constraint, not this joint,
            is what keeps the root body in place.
        """
        rotation_kwargs = spawn_rotation.as_kwargs()
        anchor_site = self.mjcf_root.worldbody.add(
            "site", name=fly.name, pos=spawn_position, **rotation_kwargs
        )
        attachment_frame = anchor_site.attach(fly.mjcf_root)
        fly_freejoint = attachment_frame.add("freejoint", name=fly.name)

        # Resolve the root body's identifier only after attaching, so that it
        # reflects the attached model's namescope within the world.
        root_body = fly.mjcf_root.find("body", fly.root_segment.name)
        tether_pose = (*spawn_position, *spawn_rotation.values)
        # NOTE(review): MuJoCo documents `relpose` as the pose of body2 relative to
        # body1. Here body1 is the fly and body2 is the world, but the fly's spawn
        # pose (fly relative to world) is passed. Confirm the intended direction.
        self.mjcf_root.equality.add(
            "weld",
            body1=root_body.full_identifier,
            body2="world",  # worldbody is called "world" in equality constraints
            relpose=tether_pose,
            solref=(2e-4, 1.0),
            solimp=(0.98, 0.99, 1e-5, 0.5, 3),
        )
        return fly_freejoint