Skip to content

Commit

Permalink
allow setting model params in RCCarSimEnv
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Jul 7, 2023
1 parent 28e272a commit 28a4168
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions sim_transfer/sims/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,32 @@ class RCCarSimEnv:
_angle_idx: int = 2
_obs_noise_stds: jnp.array = 0.05 * jnp.exp(jnp.array([-3.3170326, -3.7336411, -2.7081904,
-2.7841284, -2.7067015, -1.4446207]))
_default_car_model_params: Dict = {
'use_blend': 0.0,
'm': 1.3,
'c_m_1': 1.0,
'c_m_2': 0.2,
'c_d': 0.5,
'l_f': 0.3,
'l_r': 0.3,
'steering_limit': 0.5
}

def __init__(self, ctrl_cost_weight: float = 0.005, encode_angle: bool = False, use_obs_noise: bool = True,
seed: int = 230492394):
car_model_params: Dict = None, seed: int = 230492394):
self.dim_state: Tuple[int] = (7,) if encode_angle else (6,)
self.encode_angle: bool = encode_angle
self._rds_key = jax.random.PRNGKey(seed)

# initialize dynamics and observation noise models
self._dynamics_model = RaceCar(dt=self._dt, encode_angle=False)
self._dynamics_params = CarParams(
use_blend=0.0,
m=1.3,
c_m_1=1.0,
c_m_2=0.2,
c_d=0.5,
l_f=0.3,
l_r=0.3
) # TODO allow setting the params

if car_model_params is None:
_car_model_params = self._default_car_model_params
else:
_car_model_params = self._default_car_model_params
_car_model_params.update(car_model_params)
self._dynamics_params = CarParams(**_car_model_params)
self._next_step_fn = jax.jit(partial(self._dynamics_model.next_step, params=self._dynamics_params))

self.use_obs_noise = use_obs_noise
Expand Down Expand Up @@ -206,7 +214,3 @@ def time(self) -> float:
traj = jnp.stack(traj)

plot_rc_trajectory(traj, encode_angle=ENCODE_ANGLE)

from matplotlib import pyplot as plt
plt.plot(jnp.arange(len(rewards)), rewards)
plt.show()

0 comments on commit 28a4168

Please sign in to comment.