Skip to content

Commit

Permalink
gibson evaluation supports configurable policy selection and explorat…
Browse files Browse the repository at this point in the history
…ion thresholds
  • Loading branch information
naokiyokoyamabd committed Sep 11, 2023
1 parent 36e1e2e commit a8c9719
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions zsos/semexp_env/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from envs import make_vec_envs
from moviepy.editor import ImageSequenceClip

from zsos.semexp_env.semexp_policy import SemExpITMPolicyV3
from zsos.semexp_env.semexp_policy import SemExpITMPolicyV2, SemExpITMPolicyV3
from zsos.utils.img_utils import reorient_rescale_map, resize_images
from zsos.utils.log_saver import is_evaluated, log_episode
from zsos.utils.visualization import add_text_to_image
Expand All @@ -31,7 +31,7 @@ def main():
num_episodes = int(args.num_eval_episodes)
args.device = torch.device("cuda:0" if args.cuda else "cpu")

policy = SemExpITMPolicyV3(
policy_kwargs = dict(
text_prompt="Seems like there is a target_object ahead.",
pointnav_policy_path="data/pointnav_weights.pth",
depth_image_shape=(224, 224),
Expand All @@ -55,6 +55,19 @@ def main():
visualize=True,
)

exp_thresh = float(os.environ.get("EXPLORATION_THRESH", 0.0))
if exp_thresh > 0.0:
policy_cls = SemExpITMPolicyV3
policy_kwargs["exploration_thresh"] = exp_thresh
policy_kwargs["text_prompt"] = (
"Seems like there is a target_object ahead.|There is a lot of area to"
" explore ahead."
)
else:
policy_cls = SemExpITMPolicyV2

policy = policy_cls(**policy_kwargs)

torch.set_num_threads(1)
envs = make_vec_envs(args)
obs, infos = envs.reset()
Expand Down

0 comments on commit a8c9719

Please sign in to comment.