Commit a5b0f82

Merge remote-tracking branch 'origin/main' into main

lenarttreven committed Aug 25, 2023
2 parents: f22348c + 6542f22
Showing 6 changed files with 120 additions and 42 deletions.
@@ -1,17 +1,17 @@
import jax.tree_util
import numpy as np

from sim_transfer.hardware.car_env import CarEnv
from sim_transfer.hardware.xbox_data_recording.xboxagent import CarXbox2D
from brax.training.types import Transition
import pickle

BASE_SPEED = 0.5
-RECORDING_NAME = 'test_recording.pickle'
+RECORDING_NAME = 'test_1.pickle'


if __name__ == '__main__':
-    controller = CarXbox2D(base_speed=BASE_SPEED)
-    env = CarEnv()
+    controller = CarXbox2D(base_speed=1.0)
+    env = CarEnv(encode_angle=False)
    obs, _ = env.reset()
    stop = False
    observations = []
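The recording script above now drives the gamepad at full base speed (base_speed=1.0) and creates the env with encode_angle=False, so the raw heading angle is recorded. The rollout is presumably pickled as a brax Transition under RECORDING_NAME; a minimal sketch of that save path (the shapes and the one-step rollout are hypothetical, the field names are those of brax.training.types.Transition):

import pickle
import numpy as np
from brax.training.types import Transition

# Hypothetical one-step rollout with a 6-dim (non-encoded) state and 2-dim action.
transitions = Transition(
    observation=np.zeros((1, 6)), action=np.zeros((1, 2)), reward=np.zeros(1),
    discount=np.ones(1), next_observation=np.zeros((1, 6)), extras={})
with open('test_1.pickle', 'wb') as f:
    pickle.dump(transitions, f)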
@@ -0,0 +1,37 @@
+import pickle
+import jax.numpy as jnp
+import numpy as np
+from mbpo.optimizers.policy_optimizers.sac.sac_networks import SACNetworksModel, make_inference_fn
+from brax.training.acme import running_statistics
+import flax.linen as nn
+import jax.random as jr
+from sim_transfer.hardware.car_env import CarEnv
+
+ENCODE_ANGLE = True
+
+normalize_fn = running_statistics.normalize
+sac_networks_model = SACNetworksModel(
+    x_dim=7, u_dim=2,
+    preprocess_observations_fn=normalize_fn,
+    policy_hidden_layer_sizes=(64, 64),
+    policy_activation=nn.swish,
+    critic_hidden_layer_sizes=(64, 64),
+    critic_activation=nn.swish)
+
+inference_fn = make_inference_fn(sac_networks_model.get_sac_networks())
+
+with open('params.pkl', 'rb') as file:
+    params = pickle.load(file)
+
+
+def policy(obs):
+    dummy_obs = obs[0:7]
+    return np.asarray(inference_fn(params, deterministic=True)(dummy_obs, jr.PRNGKey(0))[0])
+
+
+env = CarEnv(encode_angle=ENCODE_ANGLE)
+obs, _ = env.reset()
+for i in range(200):
+    action = policy(obs)
+    obs, reward, terminate, info = env.step(action)
+    print(obs)
+env.close()
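The new script above replays a trained SAC policy on the hardware car: it rebuilds the network architecture used during training, loads the pickled parameters, and wraps the deterministic inference function as a plain obs -> action callable. Note that policy() feeds only obs[0:7] to the network: judging by the observation size 6 + encode_angle + 2 * num_frame_stacks in car_env.py, the observation is the 7-dim encoded state followed by stacked past actions, and the policy was trained on the state part alone (x_dim=7). A small sanity check under the CarEnv defaults (num_frame_stacks=3):

import numpy as np

obs = np.zeros(7 + 2 * 3)   # 7-dim encoded state + 3 stacked 2-dim actions = 13 entries
state_only = obs[0:7]       # the slice policy() actually passes to inference_fn
assert state_only.shape == (7,)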
@@ -4,62 +4,73 @@
import time

from sim_transfer.hardware.car_env import CarEnv
+from sim_transfer.sims.util import plot_rc_trajectory

-env = CarEnv()

+ENCODE_ANGLE = False
+env = CarEnv(encode_angle=ENCODE_ANGLE)

-def simulate_system_response(duration=0.5, velocity_max=0.8, num_runs=1):
+
+def simulate_system_response(steps=120, velocity_max=0.8, num_runs=1):
    all_time_values = []
    all_response_values = []
+    all_actions = []
    for run in range(num_runs):
        time_values = []
        response_values = []
-
+        actions = []
        time.sleep(5)

        obs, _ = env.reset()
        start_time = time.time()
-        elapsed_time = 0
-        step = 0
-        while elapsed_time < duration:
+        for step in range(steps):
            current_time = time.time() - start_time
            action = np.array([-1 * np.cos(step/30.0), velocity_max / (step/30.0 + 1)])
            # action = np.array([0, 0.8])
            t = time.time()
            next_obs, reward, done, info = env.step(action)
            print('time to set command', time.time() - t)
            t = time.time()
-            response_values.append(obs)
+            response_values.append(obs[: 6 + int(ENCODE_ANGLE)])
            print('time to get state', time.time() - t)
            obs = next_obs
            time_values.append(current_time)
-            elapsed_time = current_time
            print(current_time)
            print(action)
-            step += 1
+            actions.append(action)

        all_time_values.append(time_values)
        all_response_values.append(response_values)
+        all_actions.append(actions)
    env.close()
-    return all_time_values, all_response_values
+    return all_time_values, all_response_values, all_actions


num_runs = 1
-time_values, response_values = simulate_system_response(num_runs=num_runs)
+time_values, response_values, all_actions = simulate_system_response(steps=50, num_runs=num_runs, velocity_max=0.5)
+#
-response_array = np.stack([np.array(response_value).reshape(-1, 6 + int(ENCODE_ANGLE)) for response_value in
-                           response_values])
-
+response_array = [np.array(response_value).reshape(-1, env.observation_space.shape[0]) for response_value in
+                  response_values]
-time_values = [np.array(time_value).reshape(-1, 1) for time_value in time_values]
-num_dims = 6
-for dim in range(num_dims):
-    plt.figure(figsize=(10, 8))
-    for run in range(num_runs):
-        plt.plot(time_values[run] * 1000, response_array[run][:, dim], label=f'Run {run + 1}')
-        plt.axhline(response_array[run][0, dim], color='red', linestyle='dashed', label='initial state')
-    plt.gca().xaxis.set_major_locator(ticker.MultipleLocator(base=20))
-    plt.xlabel("Time (milli seconds)")
-    plt.ylabel(f"State {dim + 1}")
-    plt.title(f'Step response - Dimension {dim + 1}')
-    plt.legend()
-    plt.tight_layout()
-    plt.savefig(f'dimension_{dim + 1}_response.pdf')
-    plt.close()
+actions = np.stack([np.array(response_value).reshape(-1, env.action_space.shape[0]) for response_value in
+                    all_actions])
+
+for i in range(num_runs):
+    plot_rc_trajectory(response_array[i], actions[i], encode_angle=ENCODE_ANGLE)
+# time_values = [np.array(time_value).reshape(-1, 1) for time_value in time_values]
+# num_dims = 6
+# for dim in range(num_dims):
+#     plt.figure(figsize=(10, 8))
+#     for run in range(num_runs):
+#         plt.plot(time_values[run] * 1000, response_array[run][:, dim], label=f'Run {run + 1}')
+#         plt.axhline(response_array[run][0, dim], color='red', linestyle='dashed', label='initial state')
+#     plt.gca().xaxis.set_major_locator(ticker.MultipleLocator(base=20))
+#     plt.xlabel("Time (milli seconds)")
+#     plt.ylabel(f"State {dim + 1}")
+#     plt.title(f'Step response - Dimension {dim + 1}')
+#     plt.legend()
+#     plt.tight_layout()
+#     plt.savefig(f'dimension_{dim + 1}_response.pdf')
+#     plt.close()
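For intuition, the excitation in simulate_system_response sweeps the first action component with a slow cosine while the second decays hyperbolically; since the control loop targets roughly 30 Hz, step / 30.0 approximates elapsed seconds. The commanded profile can be previewed offline, no hardware required (a sketch matching the steps=50, velocity_max=0.5 call above):

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(50) / 30.0      # approximate seconds at a 30 Hz control rate
plt.plot(t, -1 * np.cos(t), label='action[0] (cosine sweep)')
plt.plot(t, 0.5 / (t + 1), label='action[1] (decaying, velocity_max=0.5)')
plt.xlabel('time [s]')
plt.legend()
plt.show()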
31 changes: 24 additions & 7 deletions sim_transfer/hardware/car_env.py
@@ -6,10 +6,10 @@
from gym.spaces import Box
from typing import Optional

-X_MIN_LIMIT = -2.55
-X_MAX_LIMIT = 2.55
-Y_MAX_LIMIT = 3.2
-Y_MIN_LIMIT = -3.2
+X_MIN_LIMIT = -3.2
+X_MAX_LIMIT = 3.2
+Y_MAX_LIMIT = 2.55
+Y_MIN_LIMIT = -2.55


class CarEnv(gym.Env):
@@ -21,6 +21,7 @@ def __init__(self,
                 num_frame_stacks: int = 3,
                 port_number: int = 8,  # leftmost usb port in the display has port number 8
                 encode_angle: bool = True,
+                 max_throttle: float = 0.5,
                 goal: np.ndarray = np.asarray([0.0, 0.0, 0.0])
                 ):
        super().__init__()
@@ -42,6 +43,7 @@
        self.max_steps = 200
        self.env_steps = 0
        high = np.ones(6 + self.encode_angle + 2 * num_frame_stacks) * np.inf
+        self.max_throttle = np.clip(max_throttle, 0.0, 1.0)
        if self.encode_angle:
            high[2:4] = 1
            high[6:] = 1
@@ -76,6 +78,16 @@ def log_mocap_info(self):
            writer.writeheader()
            writer.writerow(logs_dictionary)

+    def get_state_from_mocap(self):
+        current_state = self.controller.get_state()
+        mocap_x = current_state[[0, 3]]
+        current_state[[0, 3]] = current_state[[1, 4]]
+        current_state[[1, 4]] = mocap_x
+        current_state[2] += np.pi
+        current_state = self.normalize_theta(current_state)
+        return current_state
+
+
    def reset(
        self,
        *,
@@ -93,7 +105,7 @@ def reset(
            print("Starting controller in ~5 sec")
            time.sleep(5)
            self.controller_started = True
-        current_state = self.controller.get_state()
+        current_state = self.get_state_from_mocap()
        current_state[0:3] = current_state[0:3] - self.goal
        if self.encode_angle:
            new_state = self.get_encoded_state(current_state)
@@ -120,8 +132,13 @@ def get_encoded_state(self, true_state):
    def step(self, action):
        assert np.shape(action) == (2,)
        self.controller.control_mode()  # sets the mode to control
-        self.controller.set_command(action)  # set action
-        next_state = self.controller.get_state()  # get state
+        action = np.clip(action, -1.0, 1.0)
+        action[0] *= self.max_throttle
+        command_set_in_time = self.controller.set_command(action)  # set action
+        assert command_set_in_time, "API blocked python thread for too long"
+        time_elapsed = self.controller.get_time_elapsed()
+        next_state = self.get_state_from_mocap()  # get state
+        # next_state[[1, 4]] *= -1
        next_state[0:3] = next_state[0:3] - self.goal
        new_state = np.zeros_like(self.state)
        # if desired, encode angle
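The new get_state_from_mocap helper re-expresses the raw mocap state [x, y, theta, v_x, v_y, omega] in the car's frame: the x/y components of position and velocity are swapped and the heading is shifted by pi before being re-wrapped. A standalone sketch of the same transform, assuming normalize_theta wraps angles to [-pi, pi):

import numpy as np

def mocap_to_car(state):
    # state = [x, y, theta, v_x, v_y, omega] as reported by the mocap system
    s = state.copy()
    s[[0, 3]] = state[[1, 4]]                            # car x, v_x  <-  mocap y, v_y
    s[[1, 4]] = state[[0, 3]]                            # car y, v_y  <-  mocap x, v_x
    s[2] = (state[2] + 2 * np.pi) % (2 * np.pi) - np.pi  # heading + pi, wrapped to [-pi, pi)
    return s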
14 changes: 9 additions & 5 deletions sim_transfer/sims/envs.py
@@ -4,6 +4,7 @@

import jax
import jax.numpy as jnp
+import numpy as np

from sim_transfer.sims.dynamics_models import RaceCar, CarParams
from sim_transfer.sims.tolerance_reward import ToleranceReward
@@ -55,7 +56,7 @@ class RCCarSimEnv:
    _dt: float = 1 / 30.
    dim_action: Tuple[int] = (2,)
    _goal: jnp.array = jnp.array([0.0, 0.0, - jnp.pi / 2.])
-    _init_pose: jnp.array = jnp.array([-1.04, -1.42, jnp.pi / 2.])
+    _init_pose: jnp.array = jnp.array([1.04, -1.42, jnp.pi])
    _angle_idx: int = 2
    _obs_noise_stds: jnp.array = 0.05 * jnp.exp(jnp.array([-3.3170326, -3.7336411, -2.7081904,
                                                           -2.7841284, -2.7067015, -1.4446207]))
@@ -103,7 +104,7 @@ class RCCarSimEnv:

    def __init__(self, ctrl_cost_weight: float = 0.005, encode_angle: bool = False, use_obs_noise: bool = True,
                 use_tire_model: bool = False, action_delay: float = 0.0, car_model_params: Dict = None,
-                 seed: int = 230492394):
+                 max_throttle: float = 0.5, seed: int = 230492394):
        """
        Race car simulator environment
@@ -119,6 +120,7 @@ def __init__(self, ctrl_cost_weight: float = 0.005, encode_angle: bool = False,
        self.dim_state: Tuple[int] = (7,) if encode_angle else (6,)
        self.encode_angle: bool = encode_angle
        self._rds_key = jax.random.PRNGKey(seed)
+        self.max_throttle = jnp.clip(max_throttle, 0.0, 1.0)

        # initialize dynamics and observation noise models
        self._dynamics_model = RaceCar(dt=self._dt, encode_angle=False)
@@ -187,6 +189,8 @@ def step(self, action: jnp.array, rng_key: Optional[jax.random.PRNGKey] = None)
        """

        assert action.shape[-1:] == self.dim_action
+        action = jnp.clip(action, -1.0, 1.0)
+        action = action.at[0].set(self.max_throttle * action[0])
        # assert jnp.all(-1 <= action) and jnp.all(action <= 1), "action must be in [-1, 1]"
        rng_key = self.rds_key if rng_key is None else rng_key
@@ -247,9 +251,9 @@ def time(self) -> float:


if __name__ == '__main__':
-    ENCODE_ANGLE = True
+    ENCODE_ANGLE = False
    env = RCCarSimEnv(encode_angle=ENCODE_ANGLE,
-                      action_delay=0.07,
+                      action_delay=0.00,
                      use_tire_model=True,
                      use_obs_noise=True)

@@ -258,7 +262,7 @@ def time(self) -> float:
    traj = [s]
    rewards = []
    actions = []
-    for i in range(120):
+    for i in range(50):
        t = i / 30.
        a = jnp.array([- 1 * jnp.cos(1.0 * t), 0.8 / (t + 1)])
        s, r, _, _ = env.step(a)
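With this change the simulator and the hardware env share one action convention: actions are clipped to [-1, 1]^2 and the first component is rescaled by max_throttle (itself clipped to [0, 1]) before being applied. On the JAX side this requires a functional update, since jax arrays are immutable; a minimal sketch of the shared scaling logic:

import jax.numpy as jnp

def scale_action(action, max_throttle=0.5):
    action = jnp.clip(action, -1.0, 1.0)
    # jax arrays are immutable; .at[...].set(...) returns an updated copy
    return action.at[0].set(max_throttle * action[0])

print(scale_action(jnp.array([2.0, 0.3])))   # -> [0.5 0.3]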
9 changes: 9 additions & 0 deletions sim_transfer/sims/util.py
@@ -1,3 +1,4 @@
+import jax
import jax.numpy as jnp
from typing import Optional

@@ -52,6 +53,14 @@ def plot_rc_trajectory(traj: jnp.array, actions: Optional[jnp.array] = None, pos
    # axes[0][0].plot(traj[:, 0], traj[:, 1])
    axes[0][0].set_title('x-y')
    # Plot the velocity of the car as vectors
+    if isinstance(traj, jax.Array):
+        state_x = traj[:, [0, -3]]
+        traj = traj.at[:, [0, -3]].set(traj[:, [1, -2]])
+        traj = traj.at[:, [1, -2]].set(-state_x)
+    else:
+        state_x = traj[:, [0, -3]]
+        traj[:, [0, -3]] = traj[:, [1, -2]]
+        traj[:, [1, -2]] = -state_x
    total_vel = jnp.sqrt(traj[:, 3] ** 2 + traj[:, 4] ** 2)
    axes[0][0].quiver(traj[0:-1:3, 0], traj[0:-1:3, 1], traj[0:-1:3, 3], traj[0:-1:3, 4],
                      total_vel[0:-1:3], cmap='jet', scale=20,
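plot_rc_trajectory now rotates trajectories into the plotting frame, mapping (x, y) -> (y, -x) for both the position columns (0, 1) and the velocity columns (-3, -2). The branch on the array type is necessary because jax arrays only support functional .at[...].set(...) updates, while numpy arrays allow in-place assignment. The same dual-path pattern in isolation (a sketch, assuming a (T, 6) state layout [x, y, theta, v_x, v_y, omega]):

import jax

def rotate_for_plot(traj):
    old_x = traj[:, [0, -3]]   # fancy indexing copies, so this survives the swap below
    if isinstance(traj, jax.Array):
        traj = traj.at[:, [0, -3]].set(traj[:, [1, -2]])
        traj = traj.at[:, [1, -2]].set(-old_x)
        return traj
    traj = traj.copy()         # avoid mutating the caller's array in place
    traj[:, [0, -3]] = traj[:, [1, -2]]
    traj[:, [1, -2]] = -old_x
    return traj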
