Skip to content

Commit

Permalink
feat: enable to use multiple rgb encoders per camera
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroIshida committed Oct 24, 2024
1 parent 114870d commit 38c883b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class DiffusionConfig:
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
rgb_encoder_per_camera: bool = False
# Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
Expand Down
42 changes: 32 additions & 10 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,14 @@ def __init__(self, config: DiffusionConfig):
self._use_env_state = False
if num_images > 0:
self._use_images = True
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
if self.config.rgb_encoder_per_camera:
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
self.rgb_encoder = nn.ModuleList(encoders)
global_cond_dim += encoders[0].feature_dim * num_images
else:
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
print(f"gloabl cond dim: {global_cond_dim}")
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
global_cond_dim += config.input_shapes["observation.environment_state"][0]
Expand Down Expand Up @@ -241,14 +247,30 @@ def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
global_cond_feats = [batch["observation.state"]]
# Extract image feature (first combine batch, sequence, and camera index dims).
if self._use_images:
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
if self.config.rgb_encoder_per_camera:
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=False)
]
)
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
print("using rgb encoder per camera")
print(img_features.shape)
else:
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
print("using single rgb encoder")
print(img_features.shape)
global_cond_feats.append(img_features)

if self._use_env_state:
Expand Down

0 comments on commit 38c883b

Please sign in to comment.