From 5366f0ff80a1ff30a55fd3c47e7eb6f29ac20ca8 Mon Sep 17 00:00:00 2001
From: lenarttreven
Date: Thu, 17 Aug 2023 12:20:56 +0200
Subject: [PATCH] running (but not performing well) model-based RL loop

---
 .../rl/model_based_rl/learned_system.py |  11 +++--
 sim_transfer/rl/model_based_rl/utils.py |  40 +++++++++++++++++++
 sim_transfer/rl/race_car_sac.py         |   3 +-
 sim_transfer/sims/car_system.py         |   9 +----
 4 files changed, 52 insertions(+), 11 deletions(-)
 create mode 100644 sim_transfer/rl/model_based_rl/utils.py

diff --git a/sim_transfer/rl/model_based_rl/learned_system.py b/sim_transfer/rl/model_based_rl/learned_system.py
index 3c950d3..92cd909 100644
--- a/sim_transfer/rl/model_based_rl/learned_system.py
+++ b/sim_transfer/rl/model_based_rl/learned_system.py
@@ -34,10 +34,12 @@ def next_state(self,
         assert x.shape == (self.x_dim,) and u.shape == (self.u_dim,)
         # Create state-action pair
         z = jnp.concatenate([x, u])
+        z = z.reshape((1, -1))
         x_next_dist = self.model.predict_dist(z, include_noise=self.include_noise)
         next_key, key_sample_x_next = jr.split(dynamics_params.key)
-        x_next = x_next_dist.sample(seed=dynamics_params.key_sample_x_next)
-        new_dynamics_params = dynamics_params.update(key=next_key)
+        x_next = x_next_dist.sample(seed=key_sample_x_next)
+        x_next = x_next.reshape((self.x_dim,))
+        new_dynamics_params = dynamics_params.replace(key=next_key)
         return Normal(loc=x_next, scale=jnp.zeros_like(x_next)), new_dynamics_params
 
     def init_params(self, key: chex.PRNGKey) -> DynamicsParams:
@@ -45,7 +47,10 @@
 
 
 class LearnedCarSystem(System[DynamicsParams, CarRewardParams]):
-    def __init__(self, model, include_noise, **car_reward_kwargs):
+    def __init__(self,
+                 model: BatchedNeuralNetworkModel,
+                 include_noise: bool,
+                 **car_reward_kwargs: dict):
         reward = CarReward(**car_reward_kwargs)
         dynamics = LearnedDynamics(x_dim=reward.x_dim, u_dim=reward.u_dim, model=model, include_noise=include_noise)
         System.__init__(self, dynamics=dynamics, reward=CarReward(**car_reward_kwargs))
diff --git a/sim_transfer/rl/model_based_rl/utils.py b/sim_transfer/rl/model_based_rl/utils.py
new file mode 100644
index 0000000..8138dc6
--- /dev/null
+++ b/sim_transfer/rl/model_based_rl/utils.py
@@ -0,0 +1,40 @@
+import chex
+import jax.numpy as jnp
+from jax import random
+
+
+def split_data(x: chex.Array, y: chex.Array, test_ratio: float = 0.2, seed: int = 0):
+    """
+    Splits the data into training and test sets.
+    Parameters:
+        x (array): Input data.
+        y (array): Output data.
+        test_ratio (float): Fraction of the data to be used as test data.
+        seed (int): Seed for the random number generator.
+    Returns:
+        x_train, x_test, y_train, y_test
+    """
+    n = x.shape[0]
+    idx = jnp.arange(n)
+    rng = random.PRNGKey(seed)
+    permuted_idx = random.permutation(rng, idx)
+    test_size = int(n * test_ratio)
+    train_idx = permuted_idx[:n - test_size]  # n - test_size (not -test_size) stays correct when test_size == 0
+    test_idx = permuted_idx[n - test_size:]
+
+    x_train, x_test = x[train_idx], x[test_idx]
+    y_train, y_test = y[train_idx], y[test_idx]
+
+    return x_train, x_test, y_train, y_test
+
+
+if __name__ == "__main__":
+    x = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+    y = jnp.array([1, 0, 1, 0, 1])
+
+    x_train, x_test, y_train, y_test = split_data(x, y, test_ratio=0.4, seed=42)
+
+    print("x_train:", x_train)
+    print("x_test:", x_test)
+    print("y_train:", y_train)
+    print("y_test:", y_test)
diff --git a/sim_transfer/rl/race_car_sac.py b/sim_transfer/rl/race_car_sac.py
index 7160867..62fed4d 100644
--- a/sim_transfer/rl/race_car_sac.py
+++ b/sim_transfer/rl/race_car_sac.py
@@ -19,6 +19,7 @@
     action_delay=0.00,
     use_tire_model=True,
     use_obs_noise=True,
+    ctrl_cost_weight=0.005,
 )
 
 # Create replay buffer
@@ -46,7 +47,7 @@
 
 state = jit(env.reset)(rng=jr.PRNGKey(0))
 
-num_env_steps_between_updates = 16
+num_env_steps_between_updates = 4
 num_envs = 32
 
 sac_trainer = SAC(
diff --git a/sim_transfer/sims/car_system.py b/sim_transfer/sims/car_system.py
index f9d3d7a..930fd84 100644
--- a/sim_transfer/sims/car_system.py
+++ b/sim_transfer/sims/car_system.py
@@ -193,13 +193,10 @@ class CarRewardParams:
 class CarReward(Reward[CarRewardParams]):
     _goal: jnp.array = jnp.array([0.0, 0.0, -jnp.pi / 2.])
 
-    def __init__(self, ctrl_cost_weight: float = 0.1, action_cost_weight: float = 0.0, encode_angle: bool = False):
+    def __init__(self, ctrl_cost_weight: float = 0.005, encode_angle: bool = False):
         Reward.__init__(self, x_dim=7 if encode_angle else 6, u_dim=2)
         self.ctrl_cost_weight = ctrl_cost_weight
-        self.action_cost_weight = action_cost_weight
-
         self.encode_angle: bool = encode_angle
-
         self._reward_model = RCCarEnvReward(goal=self._goal,
                                             ctrl_cost_weight=ctrl_cost_weight,
                                             encode_angle=self.encode_angle)
@@ -219,8 +216,7 @@ def __call__(self,
 
 class CarSystem(System[CarDynamicsParams, CarRewardParams]):
     def __init__(self, encode_angle: bool = False, use_tire_model: bool = False, action_delay: float = 0.0,
-                 car_model_params: Dict = None, ctrl_cost_weight: float = 0.1, action_cost_weight: float = 0.0,
-                 use_obs_noise: bool = True):
+                 car_model_params: Dict = None, ctrl_cost_weight: float = 0.005, use_obs_noise: bool = True):
         System.__init__(self,
                         dynamics=CarDynamics(encode_angle=encode_angle,
                                              use_tire_model=use_tire_model,
@@ -228,7 +224,6 @@ def __init__(self, encode_angle: bool = False, use_tire_model: bool = False, act
                                              car_model_params=car_model_params,
                                              use_obs_noise=use_obs_noise),
                         reward=CarReward(ctrl_cost_weight=ctrl_cost_weight,
-                                         action_cost_weight=action_cost_weight,
                                          encode_angle=encode_angle)
                         )
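Note on the reshape pair added in LearnedDynamics.next_state: the patch suggests that the learned
model's predict_dist consumes batched inputs, so a single state-action pair is promoted to a batch
of one before prediction and the sampled next state is flattened back to (x_dim,). Below is a
minimal, self-contained sketch of that pattern; StandInModel and StandInDist are hypothetical
stand-ins (not the repository's BatchedNeuralNetworkModel or its predictive distribution), assumed
only to map a (batch, x_dim + u_dim) array to a distribution whose samples have shape (batch, x_dim).

import jax.numpy as jnp
import jax.random as jr


class StandInDist:
    """Hypothetical stand-in for the model's predictive distribution."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self, seed):
        # Reparameterized Gaussian sample with the given PRNG key.
        return self.loc + self.scale * jr.normal(seed, self.loc.shape)


class StandInModel:
    """Hypothetical stand-in for BatchedNeuralNetworkModel: maps a
    (batch, x_dim + u_dim) array to a distribution over (batch, x_dim)."""

    def __init__(self, x_dim):
        self.x_dim = x_dim

    def predict_dist(self, z, include_noise=True):
        assert z.ndim == 2  # batched input required, hence the reshape below
        loc = z[:, :self.x_dim]  # toy prediction: next state = current state
        scale = jnp.full_like(loc, 0.1 if include_noise else 0.0)
        return StandInDist(loc, scale)


x_dim, u_dim = 6, 2
model = StandInModel(x_dim)
x, u = jnp.zeros(x_dim), jnp.ones(u_dim)

# Same pattern as the patched next_state: promote the single state-action
# pair to a batch of one, sample, then flatten back to (x_dim,).
z = jnp.concatenate([x, u]).reshape((1, -1))  # shape (1, x_dim + u_dim)
next_key, key_sample_x_next = jr.split(jr.PRNGKey(0))
x_next = model.predict_dist(z, include_noise=True).sample(seed=key_sample_x_next)
x_next = x_next.reshape((x_dim,))
print(x_next.shape)  # -> (6,)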