-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
67 lines (60 loc) · 1.64 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gym
import torch
import wandb
import numpy as np
from cic.agent import CICAgent
from cic.trainer import CICTrainer
from cic.utils import set_seed, NormalNoise, rollout
# Device handed to the agent. Prefer CUDA, but fall back to the CPU so the
# script still starts on machines without a usable GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def main():
    """Train a CIC agent on HalfCheetah-v3 with a fixed set of hyperparameters.

    Calls ``set_seed`` with seed 32, initialises wandb (passing
    ``mode="disabled"``), builds the agent, the exploration noise and the
    trainer, and finally starts training. Takes no arguments and returns
    nothing.

    NOTE(review): the trailing ``# value`` comments next to some settings
    look like reference/default values kept for comparison - confirm.
    """
    # Values that are used by more than one component below.
    env_name = "HalfCheetah-v3"
    action_dim = 6
    total_timesteps = 2_000_000

    set_seed(seed=32)

    wandb_settings = dict(
        project="CIC",
        group="cheetah",
        name="first_run",
        entity="Howuhh",
        mode="disabled",
    )
    wandb.init(**wandb_settings)

    agent_settings = dict(
        obs_dim=17,
        action_dim=action_dim,
        skill_dim=64,  # 64
        hidden_dim=256,  # 256
        learning_rate=1e-4,
        target_tau=1e-4,  # 1e-4
        update_actor_every=1,
        device=DEVICE,
    )
    agent = CICAgent(**agent_settings)

    noise_settings = dict(
        action_dim=action_dim,
        timesteps=total_timesteps,
        max_action=1.0,
        eps_max=0.6,  # 0.6
        eps_min=0.05,
    )
    exploration = NormalNoise(**noise_settings)

    # The same environment id is used for both training and evaluation.
    trainer = CICTrainer(
        train_env=env_name,
        eval_env=env_name,
        checkpoints_path="cic_checkpoints",
    )

    training_settings = dict(
        timesteps=total_timesteps,
        start_train=4000,  # 4000
        batch_size=2048,  # 1024
        buffer_size=1_000_000,
        update_skill_every=100,  # 100
        update_every=2,  # 1
        eval_every=25_000,
    )
    trainer.train(agent=agent, exploration=exploration, **training_settings)

    # Disabled snippet: load a saved agent checkpoint and render one rollout
    # per skill value.
    # agent = torch.load("cic_checkpoints/45695034-daec-440f-b4cf-d505ffd1b251/agent_250000.pt")
    # skills = np.linspace(0.0, 1.0, 10)
    #
    # for i, skill_value in enumerate(skills):
    # env = gym.make("HalfCheetah-v3")
    # set_seed(env=env, seed=32)
    # skill = np.zeros(agent.skill_dim) + skill_value
    # rollout(env, agent, skill, render_path=f"videos/rollout_{i}.mp4", max_steps=100)
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()