Add first-pass CI and small unit test (#2)
- Add pre-commit CI check
- Update some dependencies in pyproject.toml (incomplete)
- Add test for habitat
- Fix a few things pre-commit flagged

The pip install step currently takes ~30 min, which is too long.
Pre-building a container would take this down to 3 min.
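
A rough sketch of what that could look like for the build job, assuming a pre-built image is published somewhere such as GHCR (the image name and registry below are hypothetical and not part of this commit):

jobs:
  build:
    runs-on: ubuntu-latest
    # Hypothetical image with habitat-sim and the other heavy dependencies already installed
    container:
      image: ghcr.io/bdaiinstitute/zsos-ci:latest
    steps:
      - uses: actions/checkout@v3
      - name: Install package
        # Only the zsos package itself needs installing; heavy wheels are baked into the image
        run: pip install -e .[dev]
      - name: Pytest
        run: pytest test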

---------

Co-authored-by: Naoki Yokoyama <nyokoyama@theaiinstitute.com>
jiuguangw and naokiyokoyamabd authored Jul 25, 2023
1 parent 17e04fc commit d6cd20e
Showing 12 changed files with 141 additions and 23 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/pre_commit.yml
@@ -0,0 +1,19 @@
name: Pre-Commit

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: 3.9.16
      - uses: pre-commit/action@v3.0.0
30 changes: 30 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,30 @@
# Copyright [2023] Boston Dynamics AI Institute, Inc.

name: ZSOS - Main Build

on:
  push:
    branches: [ main ]
  pull_request:

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        python-version: ['3.9.16']
        os: [ubuntu-latest]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install package
        run: |
          sudo apt-get install -y libgl1-mesa-dev
          pip install -e .[dev]
      - name: Pytest
        run: |
          pytest test
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
    rev: 23.3.0
    hooks:
      - id: black
-       language_version: python3.10
+       language_version: python3.9
        args: ['--config', 'pyproject.toml']
        exclude: 'dreamerv3/.*|grpc_infra/.*'
        verbose: true
22 changes: 15 additions & 7 deletions pyproject.toml
@@ -9,13 +9,18 @@ name = "zsos"
version = "0.1"
description = "Zero shot object search"
authors = [
-    {name = "Naoki Yokoyama", email = "naokiyokoyama@github"},
+    {name = "Naoki Yokoyama", email = "nyokoyama@theaiinstitute.com"},
]
readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.9"
dependencies = [
    "torch >= 1.13.1",
-    # "habitat @ git+https://github.com/facebookresearch/habitat-sim.git",
+    "habitat-sim @ git+https://github.com/facebookresearch/habitat-sim.git",
+    "habitat-baselines >= 0.2.4",
+    "habitat-lab",
+    "frontier_exploration @ git+https://github.com/naokiyokoyama/frontier_exploration.git",
+    "transformers == 4.28.0", # higher versions break BLIP-2
+    "flask >= 2.3.2"
]

[project.optional-dependencies]
@@ -29,6 +34,9 @@ dev = [
"Homepage" = "theaiinstitute.com"
"GitHub" = "https://github.com/bdaiinstitute/llm-object-search"

+[tool.setuptools]
+packages = ["zsos", "config"]
+
[tool.ruff]
# Enable pycodestyle (`E`), Pyflakes (`F`), and import sorting (`I`)
select = ["E", "F", "I"]
@@ -69,8 +77,8 @@ line-length = 120
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

-# Assume Python 3.10.
-target-version = "py310"
+# Assume Python 3.9.
+target-version = "py39"

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
@@ -81,7 +89,7 @@ max-complexity = 10

[tool.black]
line-length = 88
-target-version = ['py310']
+target-version = ['py39']
include = '\.pyi?$'
# `extend-exclude` is not honored when `black` is passed a file path explicitly,
# as is typical when `black` is invoked via `pre-commit`.
@@ -95,7 +103,7 @@ preview = true

# mypy configuration
[tool.mypy]
-python_version = "3.10"
+python_version = "3.9"
disallow_untyped_defs = true
ignore_missing_imports = true
explicit_package_bases = true
2 changes: 0 additions & 2 deletions scripts/eval_llm_policy.sh
@@ -2,8 +2,6 @@
# Copyright [2023] Boston Dynamics AI Institute, Inc.

python -um zsos.run \
-  --config-name=experiments/llm_objectnav_hm3d.yaml \
-  --config-path ../config \
  habitat_baselines.evaluate=True \
  habitat_baselines.eval_ckpt_path_dir=dummy_policy.pth \
  habitat_baselines.load_resume_state_config=False \
28 changes: 28 additions & 0 deletions test/test_setup.py
@@ -0,0 +1,28 @@
import os

import torch
from habitat_baselines.common.baseline_registry import baseline_registry  # noqa

from zsos import get_config


def test_load_and_save_config():
    if not os.path.exists("build"):
        os.makedirs("build")

    # Save a dummy state_dict using torch.save
    config = get_config("config/experiments/llm_objectnav_hm3d.yaml")
    dummy_dict = {
        "config": config,
        "extra_state": {"step": 0},
        "state_dict": {},
    }

    filename = "build/dummy_policy.pth"
    torch.save(dummy_dict, filename)

    # Get the file size of the saved checkpoint
    file_size = os.path.getsize(filename)

    # Check the size is greater than 30 KB
    assert file_size > 30 * 1024, "Test failed - failed to create pth"
25 changes: 25 additions & 0 deletions test/test_visualization.py
@@ -0,0 +1,25 @@
import os

import cv2

from zsos.utils.visualization import generate_text_image


def test_visualization():
    if not os.path.exists("build"):
        os.makedirs("build")

    width = 400
    text = (
        "This is a long text that needs to be drawn on an image with a specified "
        "width. The text should wrap around if it exceeds the given width."
    )

    result_image = generate_text_image(width, text)

    # Save the image to a file
    output_filename = "build/output_image.png"
    cv2.imwrite(output_filename, result_image)

    # Assert that the file exists
    assert os.path.exists(output_filename), "Output image file not found!"
3 changes: 2 additions & 1 deletion zsos/__init__.py
@@ -1,3 +1,4 @@
import frontier_exploration
from habitat import get_config

import zsos.obs_transformers.resize
from zsos.policy import base_policy, llm_policy
2 changes: 1 addition & 1 deletion zsos/policy/llm_policy.py
Expand Up @@ -3,13 +3,13 @@

import numpy as np
import torch
from frontier_exploration.policy import FrontierExplorationPolicy
from habitat.tasks.nav.object_nav_task import ObjectGoalSensor
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensor_dict import TensorDict
from habitat_baselines.rl.ppo.policy import PolicyActionData
from torch import Tensor

from frontier_exploration.policy import FrontierExplorationPolicy
from zsos.llm.llm import BaseLLM, ClientFastChat
from zsos.mapping.object_map import ObjectMap
from zsos.obs_transformers.resize import image_resize
6 changes: 4 additions & 2 deletions zsos/run.py
@@ -4,11 +4,13 @@
from habitat_baselines.run import execute_exp
from omegaconf import DictConfig

+from zsos.policy import base_policy, llm_policy  # noqa: F401


@hydra.main(
    version_base=None,
-    config_path="../habitat-lab/habitat-baselines/habitat_baselines/config",
-    config_name="pointnav/ppo_pointnav_example",
+    config_path="../config",
+    config_name="experiments/llm_objectnav_hm3d",
)
def main(cfg: DictConfig):
    cfg = patch_config(cfg)
23 changes: 15 additions & 8 deletions zsos/vlm/blip2.py
Expand Up @@ -26,20 +26,27 @@ def __init__(
)
self.device = device

def ask(self, image, prompt=None):
def ask(self, image, prompt=None) -> str:
"""Generates a caption for the given image.
Args:
image (numpy.ndarray): The input image as a numpy array.
prompt (str, optional): An optional prompt to provide context and guide
the caption generation. Can be used to ask questions about the image.
Returns:
dict: The generated caption.
"""
pil_img = Image.fromarray(image)
processed_image = (
self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
)

import time

st = time.time()
if prompt is None or prompt == "":
out = self.model.generate({"image": processed_image})
out = self.model.generate({"image": processed_image})[0]
else:
out = self.model.generate({"image": processed_image, "prompt": prompt})
print(f"Time taken: {time.time() - st:.2f}s")
out = self.model.generate({"image": processed_image, "prompt": prompt})[0]

return out

@@ -68,7 +75,7 @@ def ask(self, image: np.ndarray, prompt: Optional[str] = None) -> str:
class BLIP2Server(ServerMixin, BLIP2):
    def process_payload(self, payload: dict) -> dict:
        image = str_to_image(payload["image"])
-        return {"response": self.ask(image, payload.get("prompt"))[0]}
+        return {"response": self.ask(image, payload.get("prompt"))}

# blip = BLIP2Server(name="blip2_opt", model_type="pretrain_opt2.7b")
blip = BLIP2Server(name="blip2_t5", model_type="pretrain_flant5xl")
2 changes: 1 addition & 1 deletion zsos/vlm/grounding_dino.py
@@ -3,8 +3,8 @@
import numpy as np
import torch
import torchvision.transforms.functional as F
-
from groundingdino.util.inference import load_model, predict
+
from zsos.vlm.detections import ObjectDetections

from .server_wrapper import ServerMixin, host_model, send_request, str_to_image