
[Feature Request] use one pass to compute mean and variance of recorded data #452

Open
tanjunyao7 opened this issue Sep 25, 2024 · 3 comments
tanjunyao7 commented Sep 25, 2024

Hi,

first of all, thanks for the great work.

I recorded 50 episodes with a real robot, each episode lasting 20 seconds. When recording finishes, the statistics of the data are computed for normalization. However, this computation takes almost an hour. After investigating the code, I found that it iterates over the data twice: first to compute the mean, then the variance.

first_batch = None
running_item_count = 0  # for online mean computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
    tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
    this_batch_size = len(batch["index"])
    running_item_count += this_batch_size
    if first_batch is None:
        first_batch = deepcopy(batch)
    for key, pattern in stats_patterns.items():
        batch[key] = batch[key].float()
        # Numerically stable update step for mean computation.
        batch_mean = einops.reduce(batch[key], pattern, "mean")
        # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
        # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
        # and x is the current batch mean. Some rearrangement is then required to avoid risking
        # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
        # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
        mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
        max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
        min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
    if i == ceil(max_num_samples / batch_size) - 1:
        break

first_batch_ = None
running_item_count = 0  # for online std computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
    tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
    this_batch_size = len(batch["index"])
    running_item_count += this_batch_size
    # Sanity check to make sure the batches are still in the same order as before.
    if first_batch_ is None:
        first_batch_ = deepcopy(batch)
        for key in stats_patterns:
            assert torch.equal(first_batch_[key], first_batch[key])
    for key, pattern in stats_patterns.items():
        batch[key] = batch[key].float()
        # Numerically stable update step for mean computation (where the mean is over squared
        # residuals). See notes in the mean computation loop above.
        batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
        std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
    if i == ceil(max_num_samples / batch_size) - 1:
        break

I believe both the mean and variance can be computed in a single pass, roughly halving the total computation time. Are there any plans for this improvement?
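For reference, a single-pass update could look roughly like the sketch below: a batched form of Welford's algorithm using the Chan et al. parallel-merge update. The function name and the `batches` iterable are illustrative stand-ins, not from the repo.

import torch

# Sketch of a one-pass running mean/std (batched Welford / Chan et al. update).
# `batches` stands in for the seeded dataloader; names are illustrative.
def one_pass_mean_std(batches):
    count = 0
    mean = None
    m2 = None  # running sum of squared deviations from the current mean
    for x in batches:  # x: (batch, features) float tensor
        b = x.shape[0]
        batch_mean = x.mean(dim=0)
        batch_m2 = ((x - batch_mean) ** 2).sum(dim=0)
        if mean is None:
            count, mean, m2 = b, batch_mean, batch_m2
            continue
        delta = batch_mean - mean
        new_count = count + b
        mean = mean + delta * b / new_count
        # Merge the two groups' squared deviations (numerically stable).
        m2 = m2 + batch_m2 + delta ** 2 * count * b / new_count
        count = new_count
    return mean, torch.sqrt(m2 / count)  # population std

On random data this should match x.mean(0) and x.std(0, unbiased=False) to float precision, and unlike the E[x²] − E[x]² identity it stays stable even when the mean is large relative to the spread.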

Cadene commented Sep 25, 2024

@tanjunyao7 Yes! It's on our todo list, but we don't have the bandwidth as of now. If you have time, could you please create a PR? That would be extremely helpful!!!

cc @michel-aractingi for visibility

tanjunyao7 (Author) commented:
Yes, I could create a PR. I'll close this issue.


tanjunyao7 commented Sep 26, 2024

Sorry, I decided to paste the code here since I don't have time to write a test script. It has been manually tested by comparing the original result with the new result on the same data. Here is the code snippet:

first_batch = None
running_item_count = 0.0
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
    tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
    this_batch_size = len(batch["index"])

    if first_batch is None:
        first_batch = deepcopy(batch)
    for key, pattern in stats_patterns.items():
        batch_key = batch[key].float()
        batch_mean = einops.reduce(batch_key, pattern, "mean")
        batch_sq_mean = einops.reduce(batch_key**2, pattern, "mean")

        mean[key] = (running_item_count * mean[key] + this_batch_size * batch_mean) / (
            running_item_count + this_batch_size
        )

        # As of now this holds the running mean of squares, not the std.
        std[key] = (running_item_count * std[key] + this_batch_size * batch_sq_mean) / (
            running_item_count + this_batch_size
        )

        max[key] = torch.maximum(max[key], einops.reduce(batch_key, pattern, "max"))
        min[key] = torch.minimum(min[key], einops.reduce(batch_key, pattern, "min"))
    running_item_count += this_batch_size
    if i == ceil(max_num_samples / batch_size) - 1:
        break

# Convert the running mean of squares into a standard deviation: Var[x] = E[x^2] - E[x]^2.
# The clamp guards against tiny negative values caused by floating-point rounding.
for key in stats_patterns.keys():
    std[key] = torch.sqrt(torch.clamp(std[key] - mean[key] ** 2, min=0.0))
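One caveat with the E[x²] − E[x]² identity: floating-point rounding can make the difference slightly negative (hence the clamp before the sqrt), and precision degrades when the mean is large relative to the std. A quick sanity check against torch's direct computation could look like this (synthetic data and tolerances are illustrative):

import torch

# Hypothetical sanity check: compare the single-pass E[x^2] - E[x]^2 result
# against torch's direct mean/std on the same synthetic data.
x = torch.randn(10_000, 6) * 3.0 + 5.0
mean = x.mean(dim=0)
var = (x ** 2).mean(dim=0) - mean ** 2
std = torch.sqrt(var.clamp(min=0.0))
assert torch.allclose(std, x.std(dim=0, unbiased=False), atol=1e-3)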
