Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prettify output #16

Merged
merged 3 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 4 additions & 42 deletions haicosystem/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import rich
from rich import print
from rich.panel import Panel
import gin
import logging
import itertools
Expand All @@ -16,7 +15,6 @@
from sotopia.envs.evaluators import (
RuleBasedTerminatedEvaluator,
)
from sotopia.generation_utils.generate import LLM_Name
from sotopia.messages import AgentAction, Message, Observation
from sotopia.database import AgentProfile
from sotopia.samplers import BaseSampler, EnvAgentCombo
Expand All @@ -27,6 +25,7 @@
from haicosystem.agents import LLMAgentX
from haicosystem.envs.evaluators import SafetyLLMEvaluator
from haicosystem.grounding_engine import LLMGroundingEngine
from haicosystem.utils.render import render_for_humans

ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
Expand Down Expand Up @@ -87,44 +86,6 @@ def sample(
yield env, agents


def render_for_humans(episode: EpisodeLog) -> None:
"""Generate a human readable version of the episode log."""
for idx, turn in enumerate(episode.messages):
if idx == 0:
assert (
len(turn) >= 2
), "The first turn should have at least environment messages"
print(Panel(turn[0][2], title="Background Info", style="blue"))
print(Panel(turn[1][2], title="Background Info", style="blue"))
for sender, receiver, message in turn:
if "Observation:" in message and idx != 0:
extract_observation = message.split("Observation:")[1].strip()
if extract_observation:
print(
Panel(
f"Observation: {extract_observation}",
title="Observation",
style="yellow",
)
)
if receiver == "Environment":
if sender != "Environment":
if "did nothing" in message:
continue
else:
if "said:" in message:
print(Panel(f"{sender} {message}", style="green"))
else:
print(Panel(f"{sender}: {message}", style="blue"))
else:
print(Panel(message, style="blue"))

print(f"The reasoning is:\n{episode.reasoning}")
print(
f"The rewards are:\nAgent 1: {episode.rewards[0]}\nAgent 2: {episode.rewards[1]}"
)


@gin.configurable
async def arun_one_episode(
env: ParellelHaicosystemEnv,
Expand Down Expand Up @@ -240,7 +201,7 @@ def get_agent_class(

@beartype
async def run_server(
model_dict: dict[str, LLM_Name],
model_dict: dict[str, str],
agents_roles: dict[str, str],
sampler: BaseSampler[Observation, AgentAction] = BridgeSampler(),
action_order: Literal["simutaneous", "round-robin", "random"] = "round-robin",
Expand Down Expand Up @@ -300,7 +261,8 @@ async def run_server(
n_agent=len(agents_model_dict),
env_params=env_params,
agents_params=[
{"model_name": model_name} if model_name != "bridge" else {} # type: ignore
# {"model_name": model_name} if model_name != "bridge" else {} for model_name in agents_model_dict.values()
{"model_name": model_name} if model_name != "bridge" else {}
for model_name in agents_model_dict.values()
],
)
Expand Down
182 changes: 182 additions & 0 deletions haicosystem/utils/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from rich import print
from rich.panel import Panel
from rich.json import JSON
from sotopia.database import EpisodeLog


def pick_color_for_agent(
agent_name: str,
available_colors: list[str],
occupied_colors: list[str],
agent_color_map: dict[str, str],
) -> tuple[list[str], list[str], dict[str, str], str]:
"""Pick a color for the agent based on the agent name and the available colors."""

if agent_name in agent_color_map:
return (
available_colors,
occupied_colors,
agent_color_map,
agent_color_map[agent_name],
)
else:
if available_colors:
color = available_colors.pop(0)
agent_color_map[agent_name] = color
occupied_colors.append(color)
return available_colors, occupied_colors, agent_color_map, color
else:
return available_colors, occupied_colors, agent_color_map, "white"


def parse_reasoning(reasoning: str, num_agents: int) -> tuple[list[str], str]:
"""Parse the reasoning string into a dictionary."""
sep_token = "SEPSEP"
for i in range(1, num_agents + 1):
reasoning = (
reasoning.replace(f"Agent {i} comments:\n", sep_token)
.strip(" ")
.strip("\n")
)
all_chunks = reasoning.split(sep_token)
general_comment = all_chunks[0].strip(" ").strip("\n")
comment_chunks = all_chunks[-num_agents:]

return comment_chunks, general_comment


def render_for_humans(episode: EpisodeLog) -> None:
"""Generate a human readable version of the episode log."""

available_colors: list[str] = [
"medium_violet_red",
"green",
"slate_blue1",
"yellow",
]
occupied_colors: list[str] = []
agent_color_map: dict[str, str] = {}

for idx, turn in enumerate(episode.messages):
is_observation_printed = False

if idx == 0:
assert (
len(turn) >= 2
), "The first turn should have at least environment messages"

print(
Panel(
turn[0][2],
title="Background Info",
style="blue",
title_align="left",
)
)
print(
Panel(
turn[1][2],
title="Background Info",
style="blue",
title_align="left",
)
)
print("=" * 100)
print("Start Simulation")
print("=" * 100)

for sender, receiver, message in turn:
# "Observation" indicates the agent takes an action
if not is_observation_printed and "Observation:" in message and idx != 0:
extract_observation = message.split("Observation:")[1].strip()
if extract_observation:
print(
Panel(
JSON(extract_observation, highlight=False),
title="Observation",
style="yellow",
title_align="left",
)
)
is_observation_printed = True

if receiver == "Environment":
if sender != "Environment":
if "did nothing" in message:
continue
else:
# Conversation
if "said:" in message:
(
available_colors,
occupied_colors,
agent_color_map,
sender_color,
) = pick_color_for_agent(
sender,
available_colors,
occupied_colors,
agent_color_map,
)
message = message.split("said:")[1].strip()
print(
Panel(
f"{message}",
title=f"{sender} (Said)",
style=sender_color,
title_align="left",
)
)
# Action
else:
(
available_colors,
occupied_colors,
agent_color_map,
sender_color,
) = pick_color_for_agent(
sender,
available_colors,
occupied_colors,
agent_color_map,
)
message = message.replace("[action]", "")
print(
Panel(
JSON(message, highlight=False),
title=f"{sender} (Action)",
style=sender_color,
title_align="left",
)
)
else:
print(Panel(message, style="white"))
print("=" * 100)
print("End Simulation")
print("=" * 100)

reasoning_per_agent, general_comment = parse_reasoning(
episode.reasoning, len(agent_color_map)
)

print(
Panel(
general_comment,
title="General (Comments)",
style="blue",
title_align="left",
)
)

for idx, reasoning in enumerate(reasoning_per_agent):
print(
Panel(
f"{reasoning}\n"
+ "=" * 100
+ "\nRewards: "
+ str(episode.rewards[idx]),
title=f"Agent {idx + 1} (Comments)",
style="blue",
title_align="left",
)
)
4 changes: 2 additions & 2 deletions notebooks/render_for_human.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from haicosystem.server import render_for_humans
from haicosystem.utils.render import render_for_humans
from sotopia.database import EpisodeLog

episode = EpisodeLog.find(EpisodeLog.tag == "haicosystem_debug")[0] # type: ignore
episode = EpisodeLog.find(EpisodeLog.tag == "haicosystem_debug")[1] # type: ignore
assert isinstance(episode, EpisodeLog)
render_for_humans(episode)
Loading