Modify the world's MJCF model to make it compatible with MJWarp's GPU batch rendering.
This may reduce texture and lighting realism.
Modification happens in place. Returns True if any modifications were made, False
otherwise.
Note for developers: whenever a new MJWarp version is released, check whether any of the workarounds here can be dropped.
Source code in src/flygym/warp/rendering.py
def modify_world_for_batch_rendering(world: BaseWorld) -> bool:
    """Modify world MJCF model to make it compatible with MJWarp's GPU batch rendering.

    This may reduce texture and lighting realism.

    Three adjustments are applied, in this order:

    1. Materials belonging to a fly (first ``/``-separated name component is a
       key of ``world.fly_lookup``) lose their RGB texture and take the
       texture's ``rgb1`` as flat color. Materials whose name contains
       ``"wing"`` are additionally made half transparent.
    2. Every material that still references an RGB texture afterwards gets its
       ``texrepeat`` divided by 1000.
    3. Every body whose last ``/``-separated name component is ``c_thorax``
       receives a directional, tracking overhead light.

    Modification happens in place. The function is not idempotent: each call
    divides ``texrepeat`` again and adds the overhead lights again under the
    same names, so it should be called once per world.

    Note for developers: Check if anything here can be dropped upon new MJWarp
    releases.

    Args:
        world: World whose ``mjcf_root`` model is edited in place.

    Returns:
        True if any modifications were made, False otherwise.
    """
    is_modified = False
    rgb_role = int(mj.mjtTextureRole.mjTEXROLE_RGB)

    # Strip textures from fly body materials
    # (rendering textures on complex meshes causes MJWarp memory corruption)
    for material in world.mjcf_root.materials:
        # Don't touch things that are not part of a Fly
        if material.name.split("/")[0] not in world.fly_lookup:
            continue

        # Make wings half transparent. This is a change to the model too, so it
        # must be reflected in the return value.
        if "wing" in material.name:
            material.rgba[3] = 0.5
            is_modified = True

        # If material has a texture, remove it to reduce memory use and fall
        # back to the texture's primary color so the body keeps a similar hue.
        texture_name = material.textures[rgb_role]
        if texture_name:
            texture_element = world.mjcf_root.texture(texture_name)
            primary_color_rgb = texture_element.rgb1
            material.textures[rgb_role] = ""
            material.rgba[:3] = primary_color_rgb
            is_modified = True

    # Adjust scale of checker materials (e.g., ground): texrepeat needs to be
    # scaled down by 1000x to get the same pattern - unclear why. Only materials
    # that still reference a texture (e.g. the ground checker) need this; fly
    # materials were stripped of their textures by the loop above.
    for material in world.mjcf_root.materials:
        if material.textures[rgb_role]:
            material.texrepeat = tuple(tr / 1000 for tr in material.texrepeat)
            is_modified = True

    # Add light above each fly explicitly
    for body in world.mjcf_root.bodies:
        if body.name.split("/")[-1] != "c_thorax":
            continue
        # stacklevel=2 attributes the warning to the caller of this function.
        warnings.warn(f"Adding overhead light for body {body.name}", stacklevel=2)
        body.add_light(
            # The light name is derived from the body name with "/" replaced.
            name=body.name.replace("/", "-") + "-overheadlight",
            mode=mj.mjtCamLight.mjCAMLIGHT_TRACK,
            targetbody=body.name,
            pos=(0, 0, 30),
            dir=(0, 0, -1),
            type=mj.mjtLightType.mjLIGHT_DIRECTIONAL,
            ambient=(10, 10, 10),
            diffuse=(10, 10, 10),
            specular=(0.3, 0.3, 0.3),
        )
        is_modified = True

    return is_modified