
aot_lambda does not work #2

Open
siddancha opened this issue Nov 20, 2023 · 1 comment

@siddancha

siddancha commented Nov 20, 2023

When running python_scripts/ds_mppi/scripts/standalonePlanar2d.py with the current version of dist_grad_closest_aot in python_scripts/mlp_learn/sdf/robot_sdf.py,

def dist_grad_closest_aot(self, q):
    return self.aot_lambda(q)
    # return self.functorch_vjp(q)

I get the following error:

Weights loaded!
/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/LinDS.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.q_goal = torch.tensor(q_goal)
/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/deprecated.py:73: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.vjp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.vjp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
  warn_deprecated('vjp')
Traceback (most recent call last):
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/standalonePlanar2d.py", line 223, in <module>
    main_int()
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/standalonePlanar2d.py", line 121, in main_int
    mppi = MPPI(q_0, q_f, dh_params, obs, dt, dt_H, N_traj, DS_ARRAY, dh_a, nn_model, 2)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/MPPI.py", line 71, in __init__
    _, _, _, _, _ = self.propagate()
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/MPPI.py", line 113, in propagate
    distance, self.nn_grad = self.distance_repulsion_nn(q_prev, aot=True)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../functions/MPPI.py", line 259, in distance_repulsion_nn
    nn_dist, nn_grad, nn_minidx = self.nn_model.dist_grad_closest_aot(nn_input[:, 0:-1])
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../../mlp_learn/sdf/robot_sdf.py", line 161, in dist_grad_closest_aot
    return self.aot_lambda(q)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3725, in returned_function
    compiled_fn = create_aot_dispatcher_function(
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3379, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 757, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3525, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../../mlp_learn/sdf/robot_sdf.py", line 154, in functorch_vjp
    dists, vjp_fn = vjp(self.model.forward, points)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/deprecated.py", line 74, in vjp
    return _impl.vjp(func, *primals, has_aux=has_aux)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 267, in vjp
    return _vjp_with_argnums(func, *primals, has_aux=has_aux)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 38, in fn
    return f(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 294, in _vjp_with_argnums
    primals_out = func(*primals)
  File "/home/sancha/repos/OptimalModulationDS/python_scripts/ds_mppi/scripts/../../mlp_learn/sdf/network_macros_mod.py", line 143, in forward
    y = self.layers[0](x_nerf)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1376, in dispatch
    ) = self.validate_and_convert_non_fake_tensors(func, converter, args, kwargs)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1597, in validate_and_convert_non_fake_tensors
    args, kwargs = tree_map_only(
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 353, in tree_map_only
    return tree_map(map_only(ty)(fn), pytree)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 283, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 283, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 334, in inner
    return f(x)
  File "/home/sancha/repos/OptimalModulationDS/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1587, in validate
    raise Exception(
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.t.default(Parameter containing:
tensor([...], size=(256, 15), requires_grad=True))
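A likely reading of the traceback: aot_lambda runs functorch_vjp through AOTAutograd (aot_autograd.py sits directly above functorch_vjp in the stack), which traces the function using FakeTensors, shape-only stand-ins for its inputs. Only the explicit argument q is converted. The weights that self.model.forward reads from the module are ordinary tensors captured from the closure, so the first operation that touches one (aten.t on the (256, 15) first-layer weight) trips the FakeTensorMode check quoted in the exception. The minimal sketch below reproduces that check in isolation. It imports FakeTensorMode from torch._subclasses.fake_tensor, a private module whose behaviour may change between PyTorch releases, and the weight and input shapes are illustrative only.

```python
import torch
import torch.nn.functional as F
# Private module: this import path is an internal detail of PyTorch.
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# A real (non-fake) parameter, shaped like the weight named in the error message.
weight = torch.nn.Parameter(torch.randn(256, 15))


def run_under_fake_mode(allow_non_fake_inputs: bool):
    """Apply a linear layer to a fake input while `weight` stays a real tensor."""
    mode = FakeTensorMode(allow_non_fake_inputs=allow_non_fake_inputs)
    with mode:
        x = torch.randn(4, 15)  # factory call under the mode -> FakeTensor
        return F.linear(x, weight)  # `weight` was created outside the mode


try:
    run_under_fake_mode(allow_non_fake_inputs=False)
except Exception as err:  # Exception in 2.1; other releases may raise AssertionError
    print(type(err).__name__, "raised:", str(err)[:60])

out = run_under_fake_mode(allow_non_fake_inputs=True)
print(isinstance(out, FakeTensor), tuple(out.shape))
```

With allow_non_fake_inputs=True the real weight is converted on the fly, as the message suggests. In the traced path shown here, however, the FakeTensorMode is created inside AOTAutograd rather than by the caller, so a more typical workaround is to hand the parameters to the traced function as explicit inputs (for example with torch.func.functional_call) or to wrap the network with functorch.compile.aot_module, which lifts parameters into graph inputs. Neither has been tested against this repository.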

When I switch to functorch_vjp instead:

def dist_grad_closest_aot(self, q):
    # return self.aot_lambda(q)
    return self.functorch_vjp(q)

it works! Why does aot_lambda not work? Should I continue using functorch_vjp? Thanks!
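On whether to keep functorch_vjp: it runs eagerly, so the real module weights are never checked against a FakeTensorMode, which is consistent with it working. The log above also warns that functorch.vjp is deprecated and slated for removal in PyTorch >= 2.3, pointing to torch.func.vjp. The sketch below shows the same vector-Jacobian pattern with torch.func.vjp. The network (15 inputs, 7 outputs standing for one distance per link) and the helper dist_grad_closest are hypothetical stand-ins, not the code in robot_sdf.py; only the (256, 15) first-layer shape and the (distance, gradient, index) return triple are taken from this thread.

```python
import torch
from torch.func import vjp  # replacement for the deprecated functorch.vjp

torch.manual_seed(0)

# Hypothetical stand-in for the SDF network: 15 inputs (first-layer weight
# is (256, 15), as in the error message) and 7 outputs, one per link.
model = torch.nn.Sequential(
    torch.nn.Linear(15, 256),
    torch.nn.Tanh(),
    torch.nn.Linear(256, 7),
)


def dist_grad_closest(points: torch.Tensor):
    """Hypothetical helper: per-row closest-link distance, its gradient, and link index."""
    dists, vjp_fn = vjp(model, points)
    min_dist, min_idx = dists.min(dim=1)
    # A one-hot cotangent selects each row's closest link. Each row of `points`
    # only affects its own row of `dists`, so one VJP yields all per-row gradients.
    cotangent = torch.nn.functional.one_hot(min_idx, dists.shape[1]).to(dists.dtype)
    (grad,) = vjp_fn(cotangent)
    return min_dist, grad, min_idx


points = torch.randn(32, 15)
nn_dist, nn_grad, nn_minidx = dist_grad_closest(points)
print(nn_dist.shape, nn_grad.shape, nn_minidx.shape)
```

The single VJP with a one-hot cotangent avoids building a full (32, 7, 15) Jacobian; it relies on the rows of the batch being independent, which holds for a plain MLP without batch-coupled layers.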

@erdisayar

Do you know how these trained parameters were obtained?
https://github.com/epfl-lasa/OptimalModulationDS/tree/master/python_scripts/mlp_learn/models
