Skip to content

Commit

Permalink
prettify output
Browse files Browse the repository at this point in the history
  • Loading branch information
liweijiang committed Jul 18, 2024
1 parent 86dbea5 commit adae33e
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 11 deletions.
157 changes: 147 additions & 10 deletions haicosystem/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import rich
from rich import print
from rich.panel import Panel
from rich.json import JSON
import gin
import logging
import itertools
Expand Down Expand Up @@ -87,43 +88,178 @@ def sample(
yield env, agents


def pick_color_for_agent(
agent_name: str,
available_colors: list[str],
occupied_colors: list[str],
agent_color_map: dict[str, str],
) -> tuple[list[str], list[str], dict[str, str], str]:
"""Pick a color for the agent based on the agent name and the available colors."""

if agent_name in agent_color_map:
return (
available_colors,
occupied_colors,
agent_color_map,
agent_color_map[agent_name],
)
else:
if available_colors:
color = available_colors.pop(0)
agent_color_map[agent_name] = color
occupied_colors.append(color)
return available_colors, occupied_colors, agent_color_map, color
else:
return available_colors, occupied_colors, agent_color_map, "white"


def parse_reasoning(reasoning: str, num_agents: int) -> tuple[list[str], str]:
"""Parse the reasoning string into a dictionary."""
sep_token = "SEPSEP"
for i in range(1, num_agents + 1):
reasoning = (
reasoning.replace(f"Agent {i} comments:\n", sep_token)
.strip(" ")
.strip("\n")
)
all_chunks = reasoning.split(sep_token)
general_comment = all_chunks[0].strip(" ").strip("\n")
comment_chunks = all_chunks[-num_agents:]

return comment_chunks, general_comment


def render_for_humans(episode: EpisodeLog) -> None:
"""Generate a human readable version of the episode log."""

available_colors = ["red", "green", "blue", "yellow"]
occupied_colors = []
agent_color_map = {}

for idx, turn in enumerate(episode.messages):
is_observation_printed = False

if idx == 0:
assert (
len(turn) >= 2
), "The first turn should have at least environment messages"
print(Panel(turn[0][2], title="Background Info", style="blue"))
print(Panel(turn[1][2], title="Background Info", style="blue"))

print(
Panel(
turn[0][2],
title="Background Info",
style="blue",
title_align="left",
)
)
print(
Panel(
turn[1][2],
title="Background Info",
style="blue",
title_align="left",
)
)
print("=" * 100)
print("Start Simulation")
print("=" * 100)

for sender, receiver, message in turn:
if "Observation:" in message and idx != 0:
# "Observation" indicates the agent takes an action
if not is_observation_printed and "Observation:" in message and idx != 0:
extract_observation = message.split("Observation:")[1].strip()
if extract_observation:
print(
Panel(
f"Observation: {extract_observation}",
JSON(extract_observation, highlight=False),
title="Observation",
style="yellow",
title_align="left",
)
)
is_observation_printed = True

if receiver == "Environment":
if sender != "Environment":
if "did nothing" in message:
continue
else:
# Conversation
if "said:" in message:
print(Panel(f"{sender} {message}", style="green"))
(
available_colors,
occupied_colors,
agent_color_map,
sender_color,
) = pick_color_for_agent(
sender,
available_colors,
occupied_colors,
agent_color_map,
)
message = message.split("said:")[1].strip()
print(
Panel(
f"{message}",
title=f"{sender} (Said)",
style=sender_color,
title_align="left",
)
)
# Action
else:
print(Panel(f"{sender}: {message}", style="blue"))
(
available_colors,
occupied_colors,
agent_color_map,
sender_color,
) = pick_color_for_agent(
sender,
available_colors,
occupied_colors,
agent_color_map,
)
message = message.replace("[action]", "")
print(
Panel(
JSON(message, highlight=False),
title=f"{sender} (Action)",
style=sender_color,
title_align="left",
)
)
else:
print(Panel(message, style="blue"))
print(Panel(message, style="white"))
print("=" * 100)
print("End Simulation")
print("=" * 100)

reasoning_per_agent, general_comment = parse_reasoning(
episode.reasoning, len(agent_color_map)
)

print(f"The reasoning is:\n{episode.reasoning}")
print(
f"The rewards are:\nAgent 1: {episode.rewards[0]}\nAgent 2: {episode.rewards[1]}"
Panel(
general_comment,
title="General (Comments)",
style="blue",
title_align="left",
)
)

for idx, reasoning in enumerate(reasoning_per_agent):
print(
Panel(
f"{reasoning}\n"
+ "=" * 100
+ "\nRewards: "
+ str(episode.rewards[idx]),
title=f"Agent {idx + 1} (Comments)",
style="blue",
title_align="left",
)
)


@gin.configurable
async def arun_one_episode(
Expand Down Expand Up @@ -300,7 +436,8 @@ async def run_server(
n_agent=len(agents_model_dict),
env_params=env_params,
agents_params=[
{"model_name": model_name} if model_name != "bridge" else {} # type: ignore
# type: ignore
{"model_name": model_name} if model_name != "bridge" else {}
for model_name in agents_model_dict.values()
],
)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/render_for_human.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from haicosystem.server import render_for_humans
from sotopia.database import EpisodeLog

episode = EpisodeLog.find(EpisodeLog.tag == "haicosystem_debug")[0] # type: ignore
episode = EpisodeLog.find(EpisodeLog.tag == "haicosystem_debug")[1] # type: ignore
assert isinstance(episode, EpisodeLog)
render_for_humans(episode)

0 comments on commit adae33e

Please sign in to comment.