diff --git a/haicosystem/server.py b/haicosystem/server.py index ac70c8e..ade2ec4 100644 --- a/haicosystem/server.py +++ b/haicosystem/server.py @@ -1,7 +1,6 @@ import asyncio import rich from rich import print -from rich.panel import Panel import gin import logging import itertools @@ -16,7 +15,6 @@ from sotopia.envs.evaluators import ( RuleBasedTerminatedEvaluator, ) -from sotopia.generation_utils.generate import LLM_Name from sotopia.messages import AgentAction, Message, Observation from sotopia.database import AgentProfile from sotopia.samplers import BaseSampler, EnvAgentCombo @@ -27,6 +25,7 @@ from haicosystem.agents import LLMAgentX from haicosystem.envs.evaluators import SafetyLLMEvaluator from haicosystem.grounding_engine import LLMGroundingEngine +from haicosystem.utils.render import render_for_humans ObsType = TypeVar("ObsType") ActType = TypeVar("ActType") @@ -87,44 +86,6 @@ def sample( yield env, agents -def render_for_humans(episode: EpisodeLog) -> None: - """Generate a human readable version of the episode log.""" - for idx, turn in enumerate(episode.messages): - if idx == 0: - assert ( - len(turn) >= 2 - ), "The first turn should have at least environment messages" - print(Panel(turn[0][2], title="Background Info", style="blue")) - print(Panel(turn[1][2], title="Background Info", style="blue")) - for sender, receiver, message in turn: - if "Observation:" in message and idx != 0: - extract_observation = message.split("Observation:")[1].strip() - if extract_observation: - print( - Panel( - f"Observation: {extract_observation}", - title="Observation", - style="yellow", - ) - ) - if receiver == "Environment": - if sender != "Environment": - if "did nothing" in message: - continue - else: - if "said:" in message: - print(Panel(f"{sender} {message}", style="green")) - else: - print(Panel(f"{sender}: {message}", style="blue")) - else: - print(Panel(message, style="blue")) - - print(f"The reasoning is:\n{episode.reasoning}") - print( - f"The rewards are:\nAgent 1: {episode.rewards[0]}\nAgent 2: {episode.rewards[1]}" - ) - - @gin.configurable async def arun_one_episode( env: ParellelHaicosystemEnv, @@ -240,7 +201,7 @@ def get_agent_class( @beartype async def run_server( - model_dict: dict[str, LLM_Name], + model_dict: dict[str, str], agents_roles: dict[str, str], sampler: BaseSampler[Observation, AgentAction] = BridgeSampler(), action_order: Literal["simutaneous", "round-robin", "random"] = "round-robin", @@ -300,7 +261,8 @@ async def run_server( n_agent=len(agents_model_dict), env_params=env_params, agents_params=[ - {"model_name": model_name} if model_name != "bridge" else {} # type: ignore + # {"model_name": model_name} if model_name != "bridge" else {} for model_name in agents_model_dict.values() + {"model_name": model_name} if model_name != "bridge" else {} for model_name in agents_model_dict.values() ], ) diff --git a/haicosystem/utils/render.py b/haicosystem/utils/render.py new file mode 100644 index 0000000..0360516 --- /dev/null +++ b/haicosystem/utils/render.py @@ -0,0 +1,182 @@ +from rich import print +from rich.panel import Panel +from rich.json import JSON +from sotopia.database import EpisodeLog + + +def pick_color_for_agent( + agent_name: str, + available_colors: list[str], + occupied_colors: list[str], + agent_color_map: dict[str, str], +) -> tuple[list[str], list[str], dict[str, str], str]: + """Pick a color for the agent based on the agent name and the available colors.""" + + if agent_name in agent_color_map: + return ( + available_colors, + occupied_colors, + agent_color_map, + agent_color_map[agent_name], + ) + else: + if available_colors: + color = available_colors.pop(0) + agent_color_map[agent_name] = color + occupied_colors.append(color) + return available_colors, occupied_colors, agent_color_map, color + else: + return available_colors, occupied_colors, agent_color_map, "white" + + +def parse_reasoning(reasoning: str, num_agents: int) -> tuple[list[str], str]: + """Parse the reasoning string into a dictionary.""" + sep_token = "SEPSEP" + for i in range(1, num_agents + 1): + reasoning = ( + reasoning.replace(f"Agent {i} comments:\n", sep_token) + .strip(" ") + .strip("\n") + ) + all_chunks = reasoning.split(sep_token) + general_comment = all_chunks[0].strip(" ").strip("\n") + comment_chunks = all_chunks[-num_agents:] + + return comment_chunks, general_comment + + +def render_for_humans(episode: EpisodeLog) -> None: + """Generate a human readable version of the episode log.""" + + available_colors: list[str] = [ + "medium_violet_red", + "green", + "slate_blue1", + "yellow", + ] + occupied_colors: list[str] = [] + agent_color_map: dict[str, str] = {} + + for idx, turn in enumerate(episode.messages): + is_observation_printed = False + + if idx == 0: + assert ( + len(turn) >= 2 + ), "The first turn should have at least environment messages" + + print( + Panel( + turn[0][2], + title="Background Info", + style="blue", + title_align="left", + ) + ) + print( + Panel( + turn[1][2], + title="Background Info", + style="blue", + title_align="left", + ) + ) + print("=" * 100) + print("Start Simulation") + print("=" * 100) + + for sender, receiver, message in turn: + # "Observation" indicates the agent takes an action + if not is_observation_printed and "Observation:" in message and idx != 0: + extract_observation = message.split("Observation:")[1].strip() + if extract_observation: + print( + Panel( + JSON(extract_observation, highlight=False), + title="Observation", + style="yellow", + title_align="left", + ) + ) + is_observation_printed = True + + if receiver == "Environment": + if sender != "Environment": + if "did nothing" in message: + continue + else: + # Conversation + if "said:" in message: + ( + available_colors, + occupied_colors, + agent_color_map, + sender_color, + ) = pick_color_for_agent( + sender, + available_colors, + occupied_colors, + agent_color_map, + ) + message = message.split("said:")[1].strip() + print( + Panel( + f"{message}", + title=f"{sender} (Said)", + style=sender_color, + title_align="left", + ) + ) + # Action + else: + ( + available_colors, + occupied_colors, + agent_color_map, + sender_color, + ) = pick_color_for_agent( + sender, + available_colors, + occupied_colors, + agent_color_map, + ) + message = message.replace("[action]", "") + print( + Panel( + JSON(message, highlight=False), + title=f"{sender} (Action)", + style=sender_color, + title_align="left", + ) + ) + else: + print(Panel(message, style="white")) + print("=" * 100) + print("End Simulation") + print("=" * 100) + + reasoning_per_agent, general_comment = parse_reasoning( + episode.reasoning, len(agent_color_map) + ) + + print( + Panel( + general_comment, + title="General (Comments)", + style="blue", + title_align="left", + ) + ) + + for idx, reasoning in enumerate(reasoning_per_agent): + print( + Panel( + f"{reasoning}\n" + + "=" * 100 + + "\nRewards: " + + str(episode.rewards[idx]), + title=f"Agent {idx + 1} (Comments)", + style="blue", + title_align="left", + ) + ) diff --git a/notebooks/render_for_human.py b/notebooks/render_for_human.py index 85d129d..7888ba8 100644 --- a/notebooks/render_for_human.py +++ b/notebooks/render_for_human.py @@ -1,6 +1,6 @@ -from haicosystem.server import render_for_humans +from haicosystem.utils.render import render_for_humans from sotopia.database import EpisodeLog -episode = EpisodeLog.find(EpisodeLog.tag == "haicosystem_debug")[0] # type: ignore +episode = EpisodeLog.find(EpisodeLog.tag == "haicosystem_debug")[1] # type: ignore assert isinstance(episode, EpisodeLog) render_for_humans(episode)