Commit

added support for objectnav using Gibson

naokiyokoyamabd committed Sep 8, 2023
1 parent 6c2203f commit 6bdb7b5
Showing 14 changed files with 426 additions and 41 deletions.
9 changes: 8 additions & 1 deletion scripts/parse_jsons.py
@@ -20,6 +20,9 @@ def read_json_files(directory: str) -> List[Dict[str, Any]]:
episode_stats = []
for filename in os.listdir(directory):
if filename.endswith(".json"):
# Ignore empty files
if os.path.getsize(os.path.join(directory, filename)) == 0:
continue
with open(os.path.join(directory, filename), "r") as f:
episode_stats.append(json.load(f))
return episode_stats
@@ -97,7 +100,7 @@ def calculate_avg_fail_per_category(stats: List[Dict[str, Any]]) -> None:
# Add each row to the table
for category, stats in sorted(
category_stats.items(),
key=lambda x: (x[1]["fail_count"] / x[1]["total_count"]),
key=lambda x: x[1]["fail_count"],
reverse=True,
):
avg_failure_rate = (stats["fail_count"] / stats["total_count"]) * 100
@@ -159,6 +162,7 @@ def main() -> None:
"""
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("directory", type=str, help="Directory to process")
parser.add_argument("--compact", "-c", action="store_true", help="Compact output")
args = parser.parse_args()

episode_stats = read_json_files(args.directory)
@@ -170,6 +174,9 @@ def main() -> None:
print()
calculate_avg_performance(episode_stats)

if args.compact:
return

print()
calculate_avg_fail_per_category(episode_stats)

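For reference, a minimal sketch of what the sort-key change above does, using hypothetical numbers: categories are now ordered by absolute failure count rather than failure rate, so a common-but-moderately-failing category outranks a rare-but-often-failing one.

# Hypothetical numbers: the new key lists "chair" (30 fails, 30% rate) before
# "toilet" (9 fails, 90% rate); the old rate-based key would reverse them.
category_stats = {
    "chair": {"fail_count": 30, "total_count": 100},
    "toilet": {"fail_count": 9, "total_count": 10},
}
for category, stats in sorted(
    category_stats.items(),
    key=lambda x: x[1]["fail_count"],
    reverse=True,
):
    avg_failure_rate = (stats["fail_count"] / stats["total_count"]) * 100
    print(f"{category}: {stats['fail_count']} fails ({avg_failure_rate:.0f}%)")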
2 changes: 1 addition & 1 deletion zsos/policy/base_objectnav_policy.py
@@ -23,7 +23,7 @@
from habitat_baselines.common.tensor_dict import TensorDict

from zsos.policy.base_policy import BasePolicy
except ModuleNotFoundError:
except Exception:

class BasePolicy:
pass
2 changes: 1 addition & 1 deletion zsos/policy/itm_policy.py
@@ -15,7 +15,7 @@

try:
from habitat_baselines.common.tensor_dict import TensorDict
except ModuleNotFoundError:
except Exception:
pass

PROMPT_SEPARATOR = "|"
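Both except-clause changes above broaden the same import guard. A sketch of the pattern, with the flag name HABITAT_AVAILABLE assumed for illustration: catching Exception falls back not only when habitat_baselines is missing, but also when it is installed yet fails during import (e.g., a version conflict or a broken C extension).

try:
    from habitat_baselines.common.tensor_dict import TensorDict
    HABITAT_AVAILABLE = True  # hypothetical flag name
except Exception:  # was: except ModuleNotFoundError
    HABITAT_AVAILABLE = False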
26 changes: 21 additions & 5 deletions zsos/policy/utils/non_habitat_policy/nh_pointnav_policy.py
@@ -44,11 +44,14 @@ def forward(


class PointNavResNetNet(nn.Module):
def __init__(self):
def __init__(self, discrete_actions: bool = False, no_fwd_dict: bool = False):
super().__init__()
self.prev_action_embedding = nn.Linear(
in_features=2, out_features=32, bias=True
)
if discrete_actions:
self.prev_action_embedding = nn.Embedding(4 + 1, 32)
else:
self.prev_action_embedding = nn.Linear(
in_features=2, out_features=32, bias=True
)
self.tgt_embeding = nn.Linear(in_features=3, out_features=32, bias=True)
self.visual_encoder = ResNetEncoder()
self.visual_fc = nn.Sequential(
@@ -58,6 +61,8 @@ def __init__(self):
)
self.state_encoder = LSTMStateEncoder(576, 512, 2)
self.num_recurrent_layers = self.state_encoder.num_recurrent_layers
self.discrete_actions = discrete_actions
self.no_fwd_dict = no_fwd_dict

def forward(
self,
@@ -84,7 +89,15 @@ def forward(

x.append(self.tgt_embeding(goal_observations))

prev_actions = self.prev_action_embedding(masks * prev_actions.float())
if self.discrete_actions:
prev_actions = prev_actions.squeeze(-1)
start_token = torch.zeros_like(prev_actions)
# When the mask is zero (new episode), select index 0, an extra dummy start-token action
prev_actions = self.prev_action_embedding(
torch.where(masks.view(-1), prev_actions + 1, start_token)
)
else:
prev_actions = self.prev_action_embedding(masks * prev_actions.float())

x.append(prev_actions)

@@ -93,6 +106,9 @@ def forward(
out, rnn_hidden_states, masks, rnn_build_seq_info
)

if self.no_fwd_dict:
return out, rnn_hidden_states # type: ignore

return out, rnn_hidden_states, {}


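A self-contained sketch of the discrete previous-action embedding added above, with shapes assumed from the diff: the embedding has num_actions + 1 rows, row 0 serves as a dummy start token selected wherever the mask is zero (episode start), and real action indices are shifted up by one.

import torch
import torch.nn as nn

num_actions = 4
prev_action_embedding = nn.Embedding(num_actions + 1, 32)  # row 0 = start token

prev_actions = torch.tensor([[2], [0]])  # (batch, 1) action indices
masks = torch.tensor([[True], [False]])  # False marks an episode's first step

indices = torch.where(
    masks.view(-1),
    prev_actions.squeeze(-1) + 1,  # shift real actions past the start token
    torch.zeros_like(prev_actions.squeeze(-1)),
)
embedded = prev_action_embedding(indices)  # (batch, 32)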
66 changes: 54 additions & 12 deletions zsos/policy/utils/pointnav_policy.py
@@ -7,15 +7,32 @@
from gym.spaces import Discrete
from torch import Tensor

habitat_version = ""

try:
from habitat_baselines.common.tensor_dict import TensorDict
import habitat
from habitat_baselines.rl.ddppo.policy import PointNavResNetPolicy
from habitat_baselines.rl.ppo.policy import PolicyActionData

class PointNavResNetTensorOutputPolicy(PointNavResNetPolicy):
def act(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
policy_actions: "PolicyActionData" = super().act(*args, **kwargs)
return policy_actions.actions, policy_actions.rnn_hidden_states
habitat_version = habitat.__version__

if habitat_version == "0.1.5":
print("Using habitat 0.1.5; assuming SemExp code is being used")

class PointNavResNetTensorOutputPolicy(PointNavResNetPolicy):
def act(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
value, action, action_log_probs, rnn_hidden_states = super().act(
*args, **kwargs
)
return action, rnn_hidden_states

else:
from habitat_baselines.common.tensor_dict import TensorDict
from habitat_baselines.rl.ppo.policy import PolicyActionData

class PointNavResNetTensorOutputPolicy(PointNavResNetPolicy):
def act(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
policy_actions: "PolicyActionData" = super().act(*args, **kwargs)
return policy_actions.actions, policy_actions.rnn_hidden_states

HABITAT_BASELINES_AVAILABLE = True
except ModuleNotFoundError:
@@ -121,8 +138,6 @@ def load_pointnav_policy(file_path: str) -> PointNavResNetTensorOutputPolicy:
Returns:
PointNavResNetTensorOutputPolicy: The policy.
"""
ckpt_dict = torch.load(file_path, map_location="cpu")

if HABITAT_BASELINES_AVAILABLE:
obs_space = SpaceDict(
{
@@ -138,13 +153,40 @@ def load_pointnav_policy(file_path: str) -> PointNavResNetTensorOutputPolicy:
}
)
action_space = Discrete(4)
pointnav_policy = PointNavResNetTensorOutputPolicy.from_config(
ckpt_dict["config"], obs_space, action_space
)
pointnav_policy.load_state_dict(ckpt_dict["state_dict"])
if habitat_version == "0.1.5":
pointnav_policy = PointNavResNetTensorOutputPolicy(
obs_space,
action_space,
hidden_size=512,
num_recurrent_layers=2,
rnn_type="LSTM",
resnet_baseplanes=32,
backbone="resnet18",
normalize_visual_inputs=False,
obs_transform=None,
)
# Need to overwrite the network because its visual encoder uses an older
# version of ResNet that calculates the compression size differently
from zsos.policy.utils.non_habitat_policy.nh_pointnav_policy import (
PointNavResNetNet,
)

# print(pointnav_policy)
pointnav_policy.net = PointNavResNetNet(
discrete_actions=True, no_fwd_dict=True
)
state_dict = torch.load(file_path + ".state_dict", map_location="cpu")
else:
ckpt_dict = torch.load(file_path, map_location="cpu")
pointnav_policy = PointNavResNetTensorOutputPolicy.from_config(
ckpt_dict["config"], obs_space, action_space
)
state_dict = ckpt_dict["state_dict"]
pointnav_policy.load_state_dict(state_dict)
return pointnav_policy

else:
ckpt_dict = torch.load(file_path, map_location="cpu")
pointnav_policy = PointNavResNetTensorOutputPolicy()
current_state_dict = pointnav_policy.state_dict()
pointnav_policy.load_state_dict(
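A condensed sketch of the version gate above (the helper name load_policy_state_dict is assumed for illustration): for habitat 0.1.5, a bare state_dict is loaded from a companion file next to the weights, while newer habitat-lab checkpoints bundle the config and state_dict in a single dict.

import torch

def load_policy_state_dict(file_path: str, habitat_version: str) -> dict:
    if habitat_version == "0.1.5":
        # e.g. data/pointnav_weights.pth.state_dict next to the weights file
        return torch.load(file_path + ".state_dict", map_location="cpu")
    # newer checkpoints bundle config and weights in one dict
    return torch.load(file_path, map_location="cpu")["state_dict"]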
137 changes: 133 additions & 4 deletions zsos/semexp_env/eval.py
@@ -1,15 +1,22 @@
import os
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np
import torch
from arguments import get_args
from envs import make_vec_envs
from moviepy.editor import ImageSequenceClip

from zsos.semexp_env.semexp_policy import SemExpITMPolicyV3
from zsos.utils.img_utils import reorient_rescale_map, resize_images

os.environ["OMP_NUM_THREADS"] = "1"

args = get_args()
args.agent = "zsos" # Doesn't really matter as long as it's not "sem_exp"
args.split = "val"
args.task_config = "objnav_gibson_zsos.yaml"

np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -22,23 +29,145 @@ def main():
num_episodes = int(args.num_eval_episodes)
args.device = torch.device("cuda:0" if args.cuda else "cpu")

policy = SemExpITMPolicyV3(
text_prompt="Seems like there is a target_object ahead.",
pointnav_policy_path="data/pointnav_weights.pth",
depth_image_shape=(224, 224),
det_conf_threshold=0.8,
pointnav_stop_radius=0.9,
use_max_confidence=False,
object_map_erosion_size=5,
exploration_thresh=0.0,
obstacle_map_area_threshold=1.5, # in square meters
min_obstacle_height=0.61,
max_obstacle_height=0.88,
hole_area_thresh=100000,
use_vqa=False,
vqa_prompt="Is this ",
coco_threshold=0.8,
non_coco_threshold=0.4,
camera_height=0.88,
min_depth=0.5,
max_depth=5.0,
camera_fov=79,
image_width=640,
visualize=True,
)

torch.set_num_threads(1)
envs = make_vec_envs(args)
obs, infos = envs.reset()
print(obs, infos)

for ep_num in range(num_episodes):
vis_imgs = []
for step in range(args.max_episode_length):
action = torch.randint(1, 3, (args.num_processes,))
obs_dict = merge_obs_infos(obs, infos)
if step == 0:
masks = torch.zeros(1, 1, device=obs.device)
else:
masks = torch.ones(1, 1, device=obs.device)
action, policy_infos = policy.act(obs_dict, masks)

if "VIDEO_DIR" in os.environ:
vis_imgs.append(create_frame(policy_infos))

action = action.squeeze(0)

obs, rew, done, infos = envs.step(action)
print(obs.shape)
print(obs.device)

if done:
print("Success:", infos[0]["success"])
print("SPL:", infos[0]["spl"])
if "VIDEO_DIR" in os.environ:
generate_video(vis_imgs, infos[0])
break

print("Test successfully completed")


def merge_obs_infos(
obs: torch.Tensor, infos: Tuple[Dict, ...]
) -> Dict[str, torch.Tensor]:
"""Merge the observations and infos into a single dictionary."""
rgb = obs[:, :3, ...].permute(0, 2, 3, 1)
depth = obs[:, 3:4, ...].permute(0, 2, 3, 1)
info_dict = infos[0]

def tensor_from_numpy(
tensor: torch.Tensor, numpy_array: np.ndarray
) -> torch.Tensor:
device = tensor.device
new_tensor = torch.from_numpy(numpy_array).to(device)
return new_tensor

obs_dict = {
"rgb": rgb,
"depth": depth,
"objectgoal": info_dict["goal_name"].replace("-", " "),
"gps": tensor_from_numpy(obs, info_dict["gps"]).unsqueeze(0),
"compass": tensor_from_numpy(obs, info_dict["compass"]).unsqueeze(0),
"heading": tensor_from_numpy(obs, info_dict["heading"]).unsqueeze(0),
}

return obs_dict


def create_frame(policy_infos: Dict[str, Any]) -> np.ndarray:
vis_imgs = []
for k in ["annotated_rgb", "annotated_depth", "obstacle_map", "value_map"]:
img = policy_infos[k]
if "map" in k:
img = reorient_rescale_map(img)
if k == "annotated_depth" and np.array_equal(img, np.ones_like(img) * 255):
# Put text in the middle saying "Target not currently detected"
text = "Target not currently detected"
text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1, 1)[0]
cv2.putText(
img,
text,
(img.shape[1] // 2 - text_size[0] // 2, img.shape[0] // 2),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 0, 0),
1,
)
vis_imgs.append(img)
vis_img = np.hstack(resize_images(vis_imgs, match_dimension="height"))
return vis_img


def generate_video(frames: List[np.ndarray], infos: Dict[str, Any]) -> None:
"""
Saves the given list of RGB frames as a video at 10 FPS. Uses the infos to build
the filename, which should contain the following:
- episode_id
- scene_id
- success
- spl
- dtg
- goal_name
"""
video_dir = os.environ.get("VIDEO_DIR", "video_dir")
if not os.path.exists(video_dir):
os.makedirs(video_dir)
episode_id = int(infos["episode_id"])
scene_id = infos["scene_id"]
success = int(infos["success"])
spl = infos["spl"]
dtg = infos["distance_to_goal"]
goal_name = infos["goal_name"]
filename = (
f"epid={episode_id:03d}-scid={scene_id}-succ={success}-spl={spl:.2f}"
f"-dtg={dtg:.2f}-goal={goal_name}.mp4"
)
filename = os.path.join(video_dir, filename)
# Create a video clip from the frames
clip = ImageSequenceClip(frames, fps=10)

# Write the video file
clip.write_videofile(filename)


if __name__ == "__main__":
main()
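As a usage note, here is a hypothetical infos dict and the filename generate_video would derive from it (the scene name is made up for illustration):

# generate_video(frames, example_infos) would write, under $VIDEO_DIR
# (or "video_dir" if the variable is unset):
#   epid=007-scid=Collierville-succ=1-spl=0.84-dtg=0.15-goal=chair.mp4
example_infos = {
    "episode_id": 7,
    "scene_id": "Collierville",
    "success": 1.0,
    "spl": 0.84,
    "distance_to_goal": 0.15,
    "goal_name": "chair",
}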
2 changes: 1 addition & 1 deletion zsos/semexp_env/objnav_gibson_zsos.yaml
@@ -31,7 +31,7 @@ SIMULATOR:
TASK:
TYPE: ObjectNav-v1
POSSIBLE_ACTIONS: ["STOP", "MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT", "LOOK_UP", "LOOK_DOWN"]
SENSORS: ['GPS_SENSOR', 'COMPASS_SENSOR']
SENSORS: ['GPS_SENSOR', 'COMPASS_SENSOR', 'HEADING_SENSOR']
MEASUREMENTS: ['DISTANCE_TO_GOAL', 'SUCCESS', 'SPL']
SUCCESS:
SUCCESS_DISTANCE: 0.2
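Note that the added HEADING_SENSOR supplies the "heading" observation that merge_obs_infos in zsos/semexp_env/eval.py above reads from the episode infos.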