
base_world

BaseWorld

Bases: BaseCompositionElement, ABC

Base class for worlds. A world contains environmental features that flies can interact with (e.g., the ground) and defines how flies are attached to it (e.g., free-floating or tethered). A world can hold multiple flies, which can interact with one another.

Concrete subclasses typically override __init__ to set up environmental features (e.g., a ground plane) and must implement the abstract _attach_fly_mjcf method to define how flies are attached. See the method documentation below for details.
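The split between add_fly and _attach_fly_mjcf is a template-method pattern: the base class owns the shared bookkeeping, and each subclass only decides how a fly is connected and which world-level DoFs that creates. The sketch below mimics this with plain Python stand-ins so it runs without FlyGym or MuJoCo. ToyWorld, ToyFreeWorld, ToyTetheredWorld and the "<fly>/freejoint" naming are hypothetical and do not reproduce the real MjSpec calls.

```python
from abc import ABC, abstractmethod


class ToyWorld(ABC):
    """Hypothetical stand-in for BaseWorld: the base class owns the bookkeeping."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.fly_lookup: dict[str, object] = {}
        self.world_dof_neutral_states: set[str] = set()

    @abstractmethod
    def _attach_fly(self, fly_name: str) -> set[str]:
        """Connect the fly; return the names of any new world-level DoFs."""

    def add_fly(self, fly_name: str) -> None:
        # Shared bookkeeping stays in the base class, as in BaseWorld.add_fly
        if fly_name in self.fly_lookup:
            raise ValueError(f"Fly with name '{fly_name}' already exists.")
        self.fly_lookup[fly_name] = object()  # placeholder for a fly object
        new_dofs = self._attach_fly(fly_name)
        self.world_dof_neutral_states.update(new_dofs)


class ToyFreeWorld(ToyWorld):
    def _attach_fly(self, fly_name: str) -> set[str]:
        # A free-floating fly gets one world-level free joint (hypothetical name)
        return {f"{fly_name}/freejoint"}


class ToyTetheredWorld(ToyWorld):
    def _attach_fly(self, fly_name: str) -> set[str]:
        # A rigidly tethered fly adds no world-level DoFs
        return set()


free_world = ToyFreeWorld("arena")
free_world.add_fly("fly0")
tethered_world = ToyTetheredWorld("tether")
tethered_world.add_fly("fly0")
print(free_world.world_dof_neutral_states)  # {'fly0/freejoint'}
print(tethered_world.world_dof_neutral_states)  # set()
```

Keeping registration in the base class means a subclass cannot forget it; the subclass's only contract is the returned set of DoF names.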

PyMJCF -> MjSpec migration (v2.1.0)

FlyGym 2.1.0 dropped the PyMJCF backend in favour of MuJoCo's native MjSpec API. If you are upgrading from an earlier version, see the v2.1.0 changelog (https://neuromechfly.org/changelog/#version-210) for breaking changes and a migration guide.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `name` | | Name of the world. |
| `fly_lookup` | `dict[str, BaseFly]` | A dictionary mapping fly names to Fly objects in the world. |
| `mjcf_root` | `MjSpec` | The root element of the world's MJCF model (fly MJCF models are attached to this root). |
| `world_dof_neutral_states` | `set[str]` | A set of names of DoFs managed by the world (e.g., free joints by which flies are attached to the world). The neutral pose for these DoFs is read from the compiled model's qpos0 rest configuration in _rebuild_neutral_keyframe, so only the joint names are tracked here, not explicit state values. |
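Only joint names need to be stored because the values come from qpos0. The sketch below isolates that slicing step with plain lists: each joint starts at its qpos address and spans a number of entries set by its MuJoCo joint type (7 for a free joint, i.e. 3 position plus 4 quaternion values; 4 for a ball joint; 1 for a hinge or slide joint). The toy model layout and the copy_world_dofs helper are hypothetical; in FlyGym the addresses come from mj_model.jnt_qposadr and the dimensions from _STATE_DIM_BY_JOINT_TYPE.

```python
# Number of qpos entries per MuJoCo joint type
QPOS_DIM = {"free": 7, "ball": 4, "slide": 1, "hinge": 1}


def copy_world_dofs(
    qpos0: list[float],
    joints: dict[str, tuple[str, int]],
    world_dofs: set[str],
) -> list[float]:
    """Return a zero vector with only the world-managed joints copied from qpos0.

    `joints` maps joint name -> (joint type, qpos start address); hypothetical layout.
    """
    neutral = [0.0] * len(qpos0)
    for name in world_dofs:
        if name not in joints:
            raise RuntimeError(f"Joint '{name}' not found when rebuilding keyframe.")
        joint_type, start = joints[name]
        end = start + QPOS_DIM[joint_type]
        neutral[start:end] = qpos0[start:end]
    return neutral


# Toy model: a free joint (spawned at z = 1.0, identity quaternion) and a hinge
qpos0 = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.3]
joints = {"fly0/freejoint": ("free", 0), "fly0/hinge": ("hinge", 7)}
print(copy_world_dofs(qpos0, joints, {"fly0/freejoint"}))
# [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

The hinge entry stays at zero because it is not a world-managed DoF; in FlyGym such joints are filled from the fly's own neutral pose instead.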

Source code in src/flygym/compose/world/base_world.py
class BaseWorld(BaseCompositionElement, ABC):
    """Base class for worlds that contain environmental features that the fly can
    interact with (e.g., ground) and define how flies are attached to the world (e.g.,
    free-floating or tethered). A world can contain multiple flies that can interact
    with one another.

    Concrete subclasses typically override `__init__` to set up environmental features
    (e.g., ground plane) and `_attach_fly_mjcf` to define how flies are attached. See
    method documentation below for details.

    !!! warning "PyMJCF -> MjSpec migration (v2.1.0)"

        FlyGym 2.1.0 dropped the PyMJCF backend in favour of MuJoCo's native
        ``MjSpec`` API. If you are upgrading from an earlier version, see the
        [v2.1.0 changelog](https://neuromechfly.org/changelog/#version-210)
        for breaking changes and a migration guide.

    Attributes:
        name:
            Name of the world.
        fly_lookup:
            A dictionary mapping fly names to `Fly` objects in the world.
        mjcf_root:
            The root element of the world's MJCF model (fly MJCF models are attached to
            this root).
        world_dof_neutral_states:
            A set of names of DoFs managed by the world (e.g., free joints by which
            flies are attached to the world). The neutral pose for these DoFs is read
            from the compiled model's ``qpos0`` rest configuration in
            `_rebuild_neutral_keyframe`, so only the joint names are tracked here, not
            explicit state values.
    """

    def __init__(self, name: str) -> None:
        """Initialize the world and its underlying MJCF model.

        Concrete subclasses should call this first (i.e., `super().__init__(name)`) as
        it sets up a few essential attributes.
        """
        self._mjcf_root = mj.MjSpec()
        self._mjcf_root.modelname = name
        self._fly_lookup: dict[str, BaseFly] = {}
        self.ground_geoms: list = []
        self.legpos_to_groundcontactsensors_by_fly = None
        self.world_dof_neutral_states: set[str] = set()
        self._neutral_keyframe = self.mjcf_root.add_key(name="neutral", time=0)
        self._add_skybox()

    @override
    @property
    def mjcf_root(self) -> mj.MjSpec:
        return self._mjcf_root

    @property
    def fly_lookup(self) -> dict[str, BaseFly]:
        """Lookup for `Fly` objects in the world, keyed by fly name."""
        return self._fly_lookup

    @abstractmethod
    def _attach_fly_mjcf(
        self,
        fly: BaseFly,
        spawn_position: Vec3,
        spawn_rotation: Rotation3D,
        *args,
        **kwargs,
    ) -> set[str]:
        """Attach the fly's MJCF root to the world MJCF model.

        Concrete subclasses should implement this method instead of overriding
        `add_fly()` directly. The `add_fly()` method handles registering the fly under
        `fly_lookup` and updating neutral states; this method is responsible only for
        connecting the fly's MJCF model to the world's MJCF model.

        Use `MjSpec.attach()` to attach the fly's MjSpec to the world. See
        `_GroundContactMixin` and `TetheredWorld` for examples. More details can be
        found in the [MuJoCo model editing documentation](https://mujoco.readthedocs.io/en/stable/python.html#model-editing).

        Returns:
            The names of any world-level DoFs (joints) created by this attachment.
            Their neutral pose is taken from the compiled model's ``qpos0``, so only
            the names are needed. Return an empty set if the fly is rigidly attached
            (no new DoFs).
        """
        pass

    def _add_skybox(self):
        add_texture(
            self.mjcf_root,
            name="skybox",
            type="skybox",
            builtin="gradient",
            rgb1=(1, 1, 1),
            rgb2=(1, 1, 1),
            width=10,
            height=10,
        )

    def add_fly(
        self,
        fly: BaseFly,
        spawn_position: Vec3,
        spawn_rotation: Rotation3D,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Attach a fly to the world at the specified pose.

        The fly's MJCF model is attached to the world's MJCF model and the fly is
        registered under `fly_lookup`. Extra positional and keyword arguments are
        forwarded to the subclass's `_attach_fly_mjcf` implementation (see the
        specific world subclass for available options).

        Args:
            fly: The fly to add.
            spawn_position: Initial ``(x, y, z)`` position of the fly's attachment
                frame, in mm.
            spawn_rotation: Initial orientation as a `Rotation3D` in quaternion format.
            *args: Forwarded to `_attach_fly_mjcf`.
            **kwargs: Forwarded to `_attach_fly_mjcf`.

        Raises:
            ValueError: If a fly with the same name already exists in the world.
            ValueError: If ``spawn_rotation`` is not in quaternion format.
        """
        # Validate inputs before mutating the world or the fly, so that a rejected
        # call leaves no partially registered or attached fly behind. The
        # freejoint's neutral pose is read from the compiled model's qpos0 in
        # `_rebuild_neutral_keyframe`, which encodes the spawn orientation as a
        # quaternion, so reject other rotation formats.
        if fly.name in self._fly_lookup:
            raise ValueError(f"Fly with name '{fly.name}' already exists in the world.")
        if spawn_rotation.format != "quat":
            raise ValueError(
                "Freejoint neutral rotation can only be specified in quaternion format "
                f"for now. Got {spawn_rotation}."
            )

        # Register fly in the fly lookup
        self._fly_lookup[fly.name] = fly

        # Inherit the fly's global MuJoCo settings (timestep, gravity, integrator,
        # etc.). MjSpec.attach() uses the parent (world) spec's <option>/<compiler>
        # for the compiled model and does not merge the child's, so we apply the
        # fly's globals to the world spec here.
        set_mujoco_globals(self.mjcf_root, fly.mujoco_globals_path)

        # Remove neutral keyframes that are already generated by the fly. Neutral states
        # are globally managed at the world level. A single neutral keyframe will be
        # managed by the world from now on.
        fly_neutral_keyframes = [k for k in fly.mjcf_root.keys if k.name == "neutral"]
        for keyframe in fly_neutral_keyframes:
            fly.mjcf_root.delete(keyframe)

        # Attach the fly's MJCF root to the world MJCF model. How the fly is attached
        # (e.g., via a free joint or rigidly) is defined by the concrete world class's
        # implementation of this abstract method.
        new_dofs = self._attach_fly_mjcf(
            fly, spawn_position, spawn_rotation, *args, **kwargs
        )

        # Register only the new world-level DoF names; their neutral values are
        # filled in from qpos0 when the neutral keyframe is rebuilt.
        self.world_dof_neutral_states.update(new_dofs)
        self._rebuild_neutral_keyframe()

    def _rebuild_neutral_keyframe(self):
        mj_model, _ = self.compile()
        neutral_qpos = np.zeros(mj_model.nq)
        neutral_ctrl = np.zeros(mj_model.nu)

        # Step 1: set neutral qpos for DoFs created by the world (e.g. the free joints
        # by which flies are attached). Joints are keyed by their (prefixed) name,
        # which matches the compiled model's joint names. We use the compiled `qpos0`
        # rest pose, which already composes the spawn site transform with each fly's
        # root-body offset. This matches the previous (PyMJCF) behavior, where
        # `spawn_position` positions the fly's attachment frame rather than its root
        # body directly.
        all_world_joints = {j.name: j for j in self.mjcf_root.joints}
        for joint_name in self.world_dof_neutral_states:
            joint_element = all_world_joints.get(joint_name)
            if joint_element is None:
                raise RuntimeError(
                    f"Joint '{joint_name}' not found when rebuilding neutral keyframe."
                )
            internal_jointid = mj.mj_name2id(
                mj_model, mj.mjtObj.mjOBJ_JOINT, joint_element.name
            )
            qposadr_start = mj_model.jnt_qposadr[internal_jointid]
            qposadr_end = qposadr_start + _STATE_DIM_BY_JOINT_TYPE[joint_element.type]
            neutral_qpos[qposadr_start:qposadr_end] = mj_model.qpos0[
                qposadr_start:qposadr_end
            ]

        # Step 2: handle joints and actuators belonging to flies attached to the world
        for fly_name, fly in self.fly_lookup.items():
            # Copy neutral joint angles from fly
            qpos_filled_by_fly = fly._get_neutral_qpos(mj_model)
            indices_to_fill = qpos_filled_by_fly.nonzero()
            has_conflict = np.any(~np.isclose(neutral_qpos[indices_to_fill], 0))
            if has_conflict:
                raise FlyGymInternalError(
                    f"Conflict in neutral joint angles: fly '{fly_name}' is trying "
                    "to set neutral qpos values for DoFs that already have their "
                    "neutral qpos set."
                )
            neutral_qpos[indices_to_fill] = qpos_filled_by_fly[indices_to_fill]

            # Copy neutral actuator inputs from fly
            ctrl_filled_by_fly = fly._get_neutral_ctrl(mj_model)
            indices_to_fill = ctrl_filled_by_fly.nonzero()
            has_conflict = np.any(~np.isclose(neutral_ctrl[indices_to_fill], 0))
            if has_conflict:
                raise FlyGymInternalError(
                    f"Conflict in neutral actuator inputs: fly '{fly_name}' is trying "
                    "to set neutral ctrl values for actuators that already have their "
                    "neutral ctrl set."
                )
            neutral_ctrl[indices_to_fill] = ctrl_filled_by_fly[indices_to_fill]

        self._neutral_keyframe.qpos = neutral_qpos
        self._neutral_keyframe.ctrl = neutral_ctrl
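Step 2 of _rebuild_neutral_keyframe merges each fly's contribution into a shared vector and treats any overlap as an internal error. A plain-Python sketch of that merge rule follows. merge_contribution is a hypothetical helper, RuntimeError stands in for FlyGymInternalError, and math.isclose with abs_tol=1e-8 stands in for np.isclose(x, 0) with NumPy's default tolerances. Like the original, it treats nonzero entries as "set", so a DoF whose neutral value is exactly 0 is never detected as claimed.

```python
import math


def merge_contribution(
    merged: list[float], contribution: list[float], owner: str
) -> None:
    """Copy the nonzero entries of `contribution` into `merged`, in place.

    Raises before writing anything if one of those entries is already set
    (nonzero) in `merged`.
    """
    filled = [i for i, value in enumerate(contribution) if value != 0]
    # np.isclose(x, 0) with default tolerances reduces to |x| <= 1e-8
    conflicts = [i for i in filled if not math.isclose(merged[i], 0.0, abs_tol=1e-8)]
    if conflicts:
        raise RuntimeError(f"Conflict in neutral values from '{owner}' at {conflicts}")
    for i in filled:
        merged[i] = contribution[i]


neutral = [0.0, 0.0, 0.0, 0.0]
merge_contribution(neutral, [0.5, 0.0, 0.0, 0.0], "fly0")
merge_contribution(neutral, [0.0, 0.0, -0.2, 0.0], "fly1")
print(neutral)  # [0.5, 0.0, -0.2, 0.0]
```

The same rule is applied twice per fly in the source, once to qpos and once to ctrl.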

fly_lookup property

Lookup for Fly objects in the world, keyed by fly name.

__init__(name)

Initialize the world and its underlying MJCF model.

Concrete subclasses should call this first (i.e., super().__init__(name)) as it sets up a few essential attributes.
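Calling super().__init__(name) first matters because the base initializer assigns fresh containers such as ground_geoms; a subclass that fills them before the call would have its work discarded. In the real class the MjSpec is also created there, so touching self.mjcf_root before the call would fail outright. The toy illustration below is plain Python; ToyBase, GoodFlatWorld, BadFlatWorld and the geom string are hypothetical.

```python
class ToyBase:
    def __init__(self, name: str) -> None:
        # Mirrors BaseWorld.__init__, which (re)sets these attributes
        self.name = name
        self.ground_geoms: list[str] = []


class GoodFlatWorld(ToyBase):
    def __init__(self, name: str = "flat") -> None:
        super().__init__(name)  # base attributes exist before we use them
        self.ground_geoms.append("ground_plane")


class BadFlatWorld(ToyBase):
    def __init__(self, name: str = "flat") -> None:
        self.ground_geoms = ["ground_plane"]
        super().__init__(name)  # too late: this resets ground_geoms to []


print(GoodFlatWorld().ground_geoms)  # ['ground_plane']
print(BadFlatWorld().ground_geoms)  # []
```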

add_fly(fly, spawn_position, spawn_rotation, *args, **kwargs)

Attach a fly to the world at the specified pose.

The fly's MJCF model is attached to the world's MJCF model and the fly is registered under fly_lookup. Extra positional and keyword arguments are forwarded to the subclass's _attach_fly_mjcf implementation (see the specific world subclass for available options).
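Forwarding lets each world expose its own attachment options through the shared add_fly entry point without the base class knowing about them. The sketch below shows the mechanism in plain Python; ToyBaseWorld, ToyTetherWorld and the tether_body option are hypothetical and are not options of any real FlyGym world.

```python
from typing import Any


class ToyBaseWorld:
    def add_fly(self, fly_name: str, *args: Any, **kwargs: Any) -> str:
        # Base class does its bookkeeping, then forwards the extras untouched
        return self._attach_fly(fly_name, *args, **kwargs)

    def _attach_fly(self, fly_name: str) -> str:
        raise NotImplementedError


class ToyTetherWorld(ToyBaseWorld):
    def _attach_fly(self, fly_name: str, tether_body: str = "thorax") -> str:
        # `tether_body` is a hypothetical subclass-specific option
        return f"{fly_name} tethered at {tether_body}"


world = ToyTetherWorld()
print(world.add_fly("fly0"))  # fly0 tethered at thorax
print(world.add_fly("fly0", tether_body="head"))  # fly0 tethered at head
```

An unsupported keyword fails with a TypeError raised by the subclass method, which is where the valid options are defined.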

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fly` | `BaseFly` | The fly to add. | required |
| `spawn_position` | `Vec3` | Initial (x, y, z) position of the fly's attachment frame, in mm. | required |
| `spawn_rotation` | `Rotation3D` | Initial orientation as a Rotation3D in quaternion format. | required |
| `*args` | `Any` | Forwarded to _attach_fly_mjcf. | `()` |
| `**kwargs` | `Any` | Forwarded to _attach_fly_mjcf. | `{}` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If a fly with the same name already exists in the world. |
| `ValueError` | If spawn_rotation is not in quaternion format. |
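Both documented errors can be detected from the arguments alone, so a natural design is to run them before anything is mutated; a rejected call then leaves no half-registered fly behind. The sketch below shows that ordering. ToyRotation is a hypothetical stand-in for Rotation3D exposing only a format tag and values, and validate_new_fly is a hypothetical helper; the quaternion uses MuJoCo's (w, x, y, z) order.

```python
from dataclasses import dataclass


@dataclass
class ToyRotation:
    # Hypothetical stand-in for Rotation3D: a format tag plus values
    format: str
    values: tuple[float, ...]


def validate_new_fly(
    existing: dict[str, object], fly_name: str, rotation: ToyRotation
) -> None:
    """Run both documented checks without mutating anything."""
    if fly_name in existing:
        raise ValueError(f"Fly with name '{fly_name}' already exists in the world.")
    if rotation.format != "quat":
        raise ValueError(
            "Freejoint neutral rotation can only be specified in quaternion format "
            f"for now. Got {rotation}."
        )


lookup: dict[str, object] = {}
identity = ToyRotation("quat", (1.0, 0.0, 0.0, 0.0))
validate_new_fly(lookup, "fly0", identity)
lookup["fly0"] = object()  # register only after validation passes

try:
    validate_new_fly(lookup, "fly1", ToyRotation("euler", (0.0, 0.0, 0.0)))
except ValueError as err:
    print(err)
print(sorted(lookup))  # ['fly0']
```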
