add visualizations
tsilver-bdai committed Jul 28, 2023
1 parent 13248a5 commit 3b39799
Showing 4 changed files with 85 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -25,6 +25,7 @@ machines.txt
tests/_fake_trajs
tests/_fake_results
predicators/envs/assets/task_jsons/spot_bike_env/last.json
spot_perception_outputs

# Jetbrains IDEs
.idea/
1 change: 1 addition & 0 deletions predicators/settings.py
@@ -173,6 +173,7 @@ class GlobalSettings:
spot_fiducial_size = 44.45
spot_visualize_vision_model_outputs = False
spot_vision_detection_threshold = 0.30
spot_perception_outdir = "spot_perception_outputs"

# pddl blocks env parameters
pddl_blocks_procedural_train_min_num_blocks = 3
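Regarding the new `spot_perception_outdir` setting above: a minimal sketch of how it would be read and overridden, assuming the repo's usual `utils.update_config` mechanism (the override value here is purely illustrative):

    # Hypothetical override; by default the outputs land in
    # "spot_perception_outputs" relative to the working directory.
    from predicators import utils
    from predicators.settings import CFG

    utils.update_config({"spot_perception_outdir": "/tmp/spot_debug"})
    assert CFG.spot_perception_outdir == "/tmp/spot_debug"
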
2 changes: 1 addition & 1 deletion predicators/spot_utils/README.md
@@ -18,6 +18,6 @@ The pipeline is as follows:
- It is recommended to run on a local computer (faster connection) with a CUDA GPU (faster inference)
- Connect to server from local
- Use SSH "local port forward"
- `ssh -L 5550:localhost:5550 <IP-ADDRESS>`
- `ssh -L 5550:localhost:5550 10.17.1.102`
- Request from your local computer
- See `perception_utils.py` here, or `client.py` in the BDAI repo.
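As an illustration of the request step, a minimal sketch of querying the forwarded port from the local machine; the endpoint name and payload format are assumptions, and the real interface is defined in `perception_utils.py` and `client.py`:

    # Hypothetical client call; "/predict" and the payload keys are assumed.
    import requests

    with open("example_rgb.png", "rb") as f:
        resp = requests.post(
            "http://localhost:5550/predict",  # port matches the SSH forward above
            files={"image": f},
            data={"classes": "hammer,brush"},
            timeout=60,
        )
    resp.raise_for_status()
    print(resp.json())
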
120 changes: 82 additions & 38 deletions predicators/spot_utils/perception_utils.py
@@ -3,11 +3,15 @@

import io
import math
import os
import time
from pathlib import Path
from typing import Dict, List, Tuple

import bosdyn.client
import bosdyn.client.util
import dill as pkl
import imageio.v2 as iio
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
@@ -17,6 +21,7 @@
from numpy.typing import NDArray
from scipy import ndimage

from predicators import utils
from predicators.settings import CFG
from predicators.structs import Image

@@ -54,8 +59,13 @@ def image_to_bytes(img: PIL.Image.Image) -> io.BytesIO:
return buf


def visualize_output(im: PIL.Image.Image, masks: NDArray, input_boxes: NDArray,
classes: NDArray, scores: NDArray) -> None:
def visualize_output(im: PIL.Image.Image,
masks: NDArray,
input_boxes: NDArray,
classes: NDArray,
scores: NDArray,
prefix: str,
plot: bool = False) -> None:
"""Visualizes the output of SAM; useful for debugging.
masks, input_boxes, and scores come from the output of SAM.
@@ -64,7 +74,7 @@ def visualize_output(im: PIL.Image.Image, masks: NDArray, input_boxes: NDArray,
bounding box), classes is an array of strings, and scores is an
array of floats corresponding to confidence values.
"""
plt.figure(figsize=(10, 10))
fig = plt.figure(figsize=(10, 10))
plt.imshow(im)
for mask in masks:
show_mask(mask, plt.gca(), random_color=True)
@@ -81,7 +91,8 @@ def visualize_output(im: PIL.Image.Image, masks: NDArray, input_boxes: NDArray,
edgecolor='green',
alpha=0.5))
plt.axis('off')
plt.show()
img = utils.fig2data(fig, dpi=150)
_save_spot_perception_output(img, prefix, plot=plot)


def show_mask(mask: NDArray,
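For reference, a hypothetical call to the updated `visualize_output`; the array shapes are assumptions chosen to match how the function indexes `masks` and `input_boxes`, and `plot=False` exercises the new save-to-disk path:

    import numpy as np
    import PIL.Image

    im = PIL.Image.open("frontleft_fisheye.png")        # assumed example image
    masks = np.zeros((1, 1, 480, 640), dtype=bool)      # one binary mask
    masks[0, 0, 100:200, 150:300] = True
    boxes = np.array([[[150.0, 100.0, 300.0, 200.0]]])  # one (x1, y1, x2, y2) box
    classes = np.array(["hammer"])
    scores = np.array([0.87])
    visualize_output(im, masks, boxes, classes, scores,
                     prefix="detic_sam_frontleft_demo", plot=False)
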
@@ -170,11 +181,16 @@ def query_detic_sam(rgb_image_dict_in: Dict[str, Image], classes: List[str],
"scores": curr_scores
}

if viz:
image = PIL.Image.fromarray(rgb_image_dict_in[source_name])
# Optional visualization useful for debugging.
visualize_output(image, d["masks"], d["boxes"], d["classes"],
d["scores"])
image = PIL.Image.fromarray(rgb_image_dict_in[source_name])
# Optional visualization useful for debugging.
prefix = f"detic_sam_{source_name}_raw_outputs"
visualize_output(image,
d["masks"],
d["boxes"],
d["classes"],
d["scores"],
prefix,
plot=viz)

# Filter out detections by confidence. We threshold detections
# at a set confidence level minimum, and if there are multiple
@@ -336,10 +352,15 @@ def get_pixel_locations_with_detic_sam(
x1, y1, x2, y2 = res_segment[camera_name]['boxes'][0].squeeze()
x_c = (x1 + x2) / 2
y_c = (y1 + y2) / 2
# Plot center and segmentation mask
if plot:
plt.imshow(res_segment[camera_name]['masks'][0][0].squeeze())
plt.show()

# Save/plot center and segmentation mask
bool_segmentation_img = res_segment[camera_name]['masks'][0][0].squeeze()
debug_img = 255 * bool_segmentation_img.astype(np.uint8)
_save_spot_perception_output(
debug_img,
prefix=f"detic_sam_{camera_name}_{obj_class}_segmentation",
plot=plot)

pixel_locations.append((x_c, y_c))

return pixel_locations
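A quick aside on the `255 * bool_segmentation_img.astype(np.uint8)` conversion above: image writers expect integer pixel data, so the boolean SAM mask is mapped onto {0, 255} grayscale before saving. A self-contained illustration:

    import numpy as np

    mask = np.array([[True, False], [False, True]])
    debug_img = 255 * mask.astype(np.uint8)  # 255 where True, 0 where False
    assert debug_img.dtype == np.uint8
    assert debug_img.tolist() == [[255, 0], [0, 255]]
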
@@ -378,10 +399,11 @@ def get_object_locations_with_detic_sam(
rotated_rgb_image_dict[source_name] = rotated_rgb
rotated_depth_image_dict[source_name] = rotated_depth

# Plot the rotated image before querying DETIC-SAM.
if plot:
plt.imshow(rotated_rgb)
plt.show()
# Save/plot the rotated image before querying DETIC-SAM.
_save_spot_perception_output(
rotated_rgb,
prefix=f"detic_sam_{source_name}_object_locs_inputs",
plot=plot)

# Start by querying the DETIC-SAM model.
deticsam_results_all_cameras = query_detic_sam(
@@ -439,27 +461,32 @@
# object that comes from SAM, and (5) the centroid
# after we rotate it back to align with the original
# RGB image.
if plot:
inverse_rotation_angle = -ROTATION_ANGLE[source_name]
plt.imshow(depth_image_dict[source_name])
plt.imshow(ndimage.rotate(
curr_res_segment['masks'][i][0].squeeze(),
inverse_rotation_angle,
reshape=False),
alpha=0.5,
cmap='Reds')
plt.scatter(x=x_c_rotated,
y=y_c_rotated,
marker='*',
color='red',
zorder=3)
plt.scatter(x=center[0],
y=center[1],
marker='.',
color='blue',
zorder=3)
plt.scatter(x=x_c, y=y_c, marker='*', color='green', zorder=3)
plt.show()
inverse_rotation_angle = -ROTATION_ANGLE[source_name]
fig = plt.figure()
plt.imshow(depth_image_dict[source_name])
plt.imshow(ndimage.rotate(
curr_res_segment['masks'][i][0].squeeze(),
inverse_rotation_angle,
reshape=False),
alpha=0.5,
cmap='Reds')
plt.scatter(x=x_c_rotated,
y=y_c_rotated,
marker='*',
color='red',
zorder=3)
plt.scatter(x=center[0],
y=center[1],
marker='.',
color='blue',
zorder=3)
plt.scatter(x=x_c, y=y_c, marker='*', color='green', zorder=3)
debug_img = utils.fig2data(fig, dpi=150)
_save_spot_perception_output(
debug_img,
prefix=
f"detic_sam_{source_name}_{obj_class}_object_locs_outputs",
plot=plot)

# Get XYZ of the point at center of bounding box and median depth
# value.
@@ -474,3 +501,20 @@
x0, y0, z0)

return ret_camera_to_obj_positions


def _save_spot_perception_output(img: Image,
                                 prefix: str,
                                 plot: bool = False) -> None:
    """Optionally display a perception debug image, then always save it
    under CFG.spot_perception_outdir with a timestamped filename."""
    if plot:
        plt.close()
        plt.figure()
        plt.axis("off")
        plt.imshow(img)
        plt.show()
    # Save image for debugging.
    time_str = time.strftime("%Y%m%d-%H%M%S")
    filename = f"{prefix}_{time_str}.png"
    outfile = Path(CFG.spot_perception_outdir) / filename
    os.makedirs(CFG.spot_perception_outdir, exist_ok=True)
    iio.imsave(outfile, img)
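
Each file name embeds a second-resolution timestamp, so successive calls (in different seconds) produce distinct files rather than overwriting earlier ones. A small sketch for inspecting what was written:

    from pathlib import Path
    from predicators.settings import CFG

    # File names follow the f"{prefix}_{time_str}.png" pattern above.
    for png in sorted(Path(CFG.spot_perception_outdir).glob("*.png")):
        print(png.name)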
