Skip to content

Commit

Permalink
running (but not performing well) model based rl loop
Browse files Browse the repository at this point in the history
  • Loading branch information
lenarttreven committed Aug 17, 2023
1 parent 6cb8033 commit 5366f0f
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
11 changes: 8 additions & 3 deletions sim_transfer/rl/model_based_rl/learned_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,23 @@ def next_state(self,
assert x.shape == (self.x_dim,) and u.shape == (self.u_dim,)
# Create state-action pair
z = jnp.concatenate([x, u])
z = z.reshape((1, -1))
x_next_dist = self.model.predict_dist(z, include_noise=self.include_noise)
next_key, key_sample_x_next = jr.split(dynamics_params.key)
x_next = x_next_dist.sample(seed=dynamics_params.key_sample_x_next)
new_dynamics_params = dynamics_params.update(key=next_key)
x_next = x_next_dist.sample(seed=key_sample_x_next)
x_next = x_next.reshape((self.x_dim,))
new_dynamics_params = dynamics_params.replace(key=next_key)
return Normal(loc=x_next, scale=jnp.zeros_like(x_next)), new_dynamics_params

# Build the initial dynamics parameters: a DynamicsParams whose `key` field is the given PRNG key.
def init_params(self, key: chex.PRNGKey) -> DynamicsParams:
return DynamicsParams(key=key)


class LearnedCarSystem(System[DynamicsParams, CarRewardParams]):
def __init__(self, model, include_noise, **car_reward_kwargs):
def __init__(self,
model: BatchedNeuralNetworkModel,
include_noise: bool,
**car_reward_kwargs: dict):
reward = CarReward(**car_reward_kwargs)
dynamics = LearnedDynamics(x_dim=reward.x_dim, u_dim=reward.u_dim, model=model, include_noise=include_noise)
System.__init__(self, dynamics=dynamics, reward=CarReward(**car_reward_kwargs))
Expand Down
40 changes: 40 additions & 0 deletions sim_transfer/rl/model_based_rl/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import chex
import jax.numpy as jnp
from jax import random


def split_data(x: chex.Array, y: chex.Array, test_ratio=0.2, seed=0):
    """
    Randomly split paired data into training and test sets.

    The rows of ``x`` and ``y`` are shuffled with the same permutation, so
    the pairing between inputs and outputs is preserved in both splits.

    Parameters:
    x (array): Input data; the split is performed along the first axis.
    y (array): Output data, indexed with the same row indices as ``x``.
    test_ratio (float): Fraction of the data to be used as test data.
        Must lie in [0, 1]. The test size is ``int(n * test_ratio)``, i.e.
        it is rounded down.
    seed (int): Seed for random number generator.

    Returns:
    x_train, x_test, y_train, y_test

    Raises:
    ValueError: If ``test_ratio`` is outside the interval [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")

    n = x.shape[0]
    idx = jnp.arange(n)
    rng = random.PRNGKey(seed)
    permuted_idx = random.permutation(rng, idx)
    test_size = int(n * test_ratio)

    # Slice from an explicit split point rather than with ``-test_size``:
    # when test_size is 0, ``[:-0]`` is empty and ``[-0:]`` is everything,
    # which would silently put all samples into the test set.
    n_train = n - test_size
    train_idx = permuted_idx[:n_train]
    test_idx = permuted_idx[n_train:]

    x_train, x_test = x[train_idx], x[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    return x_train, x_test, y_train, y_test


if __name__ == "__main__":
    # Smoke test: split a five-sample toy dataset and print each part.
    features = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    targets = jnp.array([1, 0, 1, 0, 1])

    parts = split_data(features, targets, test_ratio=0.4, seed=42)

    # split_data returns the parts in exactly this order.
    labels = ("x_train:", "x_test:", "y_train:", "y_test:")
    for label, part in zip(labels, parts):
        print(label, part)
3 changes: 2 additions & 1 deletion sim_transfer/rl/race_car_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
action_delay=0.00,
use_tire_model=True,
use_obs_noise=True,
ctrl_cost_weight=0.005,
)

# Create replay buffer
Expand Down Expand Up @@ -46,7 +47,7 @@

state = jit(env.reset)(rng=jr.PRNGKey(0))

num_env_steps_between_updates = 16
num_env_steps_between_updates = 4
num_envs = 32

sac_trainer = SAC(
Expand Down
9 changes: 2 additions & 7 deletions sim_transfer/sims/car_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,10 @@ class CarRewardParams:
class CarReward(Reward[CarRewardParams]):
_goal: jnp.array = jnp.array([0.0, 0.0, - jnp.pi / 2.])

def __init__(self, ctrl_cost_weight: float = 0.1, action_cost_weight: float = 0.0, encode_angle: bool = False):
def __init__(self, ctrl_cost_weight: float = 0.005, encode_angle: bool = False):
Reward.__init__(self, x_dim=7 if encode_angle else 6, u_dim=2)
self.ctrl_cost_weight = ctrl_cost_weight
self.action_cost_weight = action_cost_weight

self.encode_angle: bool = encode_angle

self._reward_model = RCCarEnvReward(goal=self._goal,
ctrl_cost_weight=ctrl_cost_weight,
encode_angle=self.encode_angle)
Expand All @@ -219,16 +216,14 @@ def __call__(self,

class CarSystem(System[CarDynamicsParams, CarRewardParams]):
def __init__(self, encode_angle: bool = False, use_tire_model: bool = False, action_delay: float = 0.0,
car_model_params: Dict = None, ctrl_cost_weight: float = 0.1, action_cost_weight: float = 0.0,
use_obs_noise: bool = True):
car_model_params: Dict = None, ctrl_cost_weight: float = 0.005, use_obs_noise: bool = True):
System.__init__(self,
dynamics=CarDynamics(encode_angle=encode_angle,
use_tire_model=use_tire_model,
action_delay=action_delay,
car_model_params=car_model_params,
use_obs_noise=use_obs_noise),
reward=CarReward(ctrl_cost_weight=ctrl_cost_weight,
action_cost_weight=action_cost_weight,
encode_angle=encode_angle)
)

Expand Down

0 comments on commit 5366f0f

Please sign in to comment.