Incorrect shape for vae_encoder_input for ACT #194

Closed
radekosmulski opened this issue May 17, 2024 · 2 comments
@radekosmulski (Contributor)

System Info

lerobot main

Information

  • One of the scripts in the examples/ folder of LeRobot
  • My own task or dataset (give details below)

Reproduction

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.datasets.utils import cycle

from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTPolicy

import torch
import multiprocessing as mp

from aim import Run

config = ACTConfig()

delta_timestamps = {
    "observation.images.top": [0],
    "observation.state": [0],
    "action": [t / 50 for t in range(config.chunk_size)],  # this dataset was recorded at 50Hz
}
dataset = LeRobotDataset('lerobot/aloha_sim_transfer_cube_scripted', split='train', delta_timestamps=delta_timestamps)

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=mp.cpu_count()-1,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

policy = ACTPolicy(config=config, dataset_stats=dataset.stats)

for batch in dataloader:
    output_dict = policy.forward(batch)

Expected behavior

Error:

File ~/workspace/lerobot/lerobot/common/policies/act/modeling_act.py:300, in ACT.forward(self, batch)
    296 robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
    297     1
    298 )  # (B, 1, D)
    299 action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
--> 300 vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
    302 # Prepare fixed positional embedding.
    303 # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
    304 pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)

RuntimeError: Tensors must have same number of dimensions: got 3 and 4

Here are the shapes:

ipdb>  cls_embed.shape
torch.Size([8, 1, 512])
ipdb>  robot_state_embed.shape
torch.Size([8, 1, 1, 512])
ipdb>  action_embed.shape
torch.Size([8, 100, 512])

That is a trivial fix, but I'm not sure how it's possible that I'm getting this error message. I just saw some commits related to ACT merged into main, so I suppose other people would be seeing this issue as well?

Anyhow, it's getting late, so I'll call it a day. I'd appreciate it if another set of eyes could take a look at this and confirm whether it is indeed a genuine issue.

If it is, it might be good to add an example for training ACT; I can work on adding one if there is interest. Or maybe just a unit test? That might be even better, since I suppose an example would not add much on top of what's already in the DiffusionPolicy training example (just different delta_timestamps and normalization; maybe the normalization alone is reason enough for another example?).
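For reference, a minimal sketch of one possible user-side workaround (hypothetical, and not the in-library fix alluded to above): squeeze out the length-1 sequence dimension that the observation delta_timestamps add, before calling the policy.

# Hypothetical workaround sketch: delta_timestamps of [0] gives the observation
# keys a length-1 sequence dimension, which ACT's forward() does not expect.
# Dropping that dimension restores the shapes ACT currently handles.
for batch in dataloader:
    batch["observation.state"] = batch["observation.state"].squeeze(1)            # (B, 1, D) -> (B, D)
    batch["observation.images.top"] = batch["observation.images.top"].squeeze(1)  # (B, 1, C, H, W) -> (B, C, H, W)
    output_dict = policy.forward(batch)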

@alexander-soare (Collaborator) commented May 17, 2024

Thanks a lot for reporting the issue and making it easy to reproduce!

This happens because you have provided delta_timestamps for the observation keys, which adds a sequence dimension that ACT does not expect (it currently only handles a single observation step). So if you do:

delta_timestamps = {
    # "observation.images.top": [0],
    # "observation.state": [0],
    "action": [t / 50 for t in range(config.chunk_size)],  # this dataset was recorded at 50Hz
}

you'll be good to run the training.
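As a quick sanity check (a sketch, assuming the same dataset and dataloader as in the reproduction above), the batch shapes then come out as ACT expects:

# Sketch: with delta_timestamps set only for "action", the observation keys keep
# their plain (B, ...) shapes and only "action" gains a sequence dimension.
batch = next(iter(dataloader))
print(batch["observation.state"].shape)       # (B, state_dim), no extra sequence dim
print(batch["observation.images.top"].shape)  # (B, C, H, W)
print(batch["action"].shape)                  # (B, chunk_size, action_dim)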

I'm not surprised this caught you out, though, and I'll need to think about how to make it more intuitive. We have some changes in the works in #163, so when I get back to it, I'll keep this issue in mind.

alexander-soare added the 🧠 Policies and 🗃️ Dataset labels May 17, 2024
@radekosmulski (Contributor, Author)

Thank you @alexander-soare! Appreciate the explanation 🙏 🙂
