Source code for momaland.utils.conversions

"""Copied and pasted from PZ, adapted to handle vectorial rewards."""

import copy
from collections import defaultdict
from typing_extensions import override

import numpy as np
from pettingzoo.utils import agent_selector
from pettingzoo.utils.conversions import (
    ActionType,
    aec_to_parallel_wrapper,
    parallel_to_aec_wrapper,
)
from pettingzoo.utils.env import AECEnv, ParallelEnv

from momaland.utils.env import MOAECEnv, MOParallelEnv


def mo_aec_to_parallel(aec_env: AECEnv) -> ParallelEnv:
    """Converts a MO aec environment to a parallel environment.

    In the case of an existing parallel environment wrapped using a `parallel_to_aec_wrapper`, this function will return the original parallel environment.
    Otherwise, it will apply the `aec_to_parallel_wrapper` to convert the environment.
    """
    if isinstance(aec_env, parallel_to_aec_wrapper):
        return aec_env.env
    else:
        par_env = mo_aec_to_parallel_wrapper(aec_env)
        return par_env
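
# Illustrative usage sketch (not part of the upstream module): converting an MO AEC
# environment to its parallel counterpart. `mobeach_v0` is assumed here as an example
# MOMAland environment; any environment implementing MOAECEnv can be substituted.
def _example_mo_aec_to_parallel():
    from momaland.envs.beach import mobeach_v0  # assumed example environment

    aec_env = mobeach_v0.env()
    par_env = mo_aec_to_parallel(aec_env)
    observations, infos = par_env.reset(seed=42)
    actions = {agent: par_env.action_space(agent).sample() for agent in par_env.agents}
    observations, rewards, terminations, truncations, infos = par_env.step(actions)
    return rewards  # dict mapping each agent to a numpy reward vector (one entry per objective)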


def mo_parallel_to_aec(par_env: ParallelEnv) -> AECEnv:
    """Converts a MO aec environment to a parallel environment.

    In the case of an existing aec environment wrapped using a `aec_to_prallel_wrapper`, this function will return the original AEC environment.
    Otherwise, it will apply the `parallel_to_aec_wrapper` to convert the environment.
    """
    if isinstance(par_env, aec_to_parallel_wrapper):
        return par_env.aec_env
    else:
        aec_env = mo_parallel_to_aec_wrapper(par_env)
        return aec_env
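
# Illustrative usage sketch (not part of the upstream module): converting an MO parallel
# environment back to the AEC API. Again, `mobeach_v0` is only an assumed example environment.
def _example_mo_parallel_to_aec():
    from momaland.envs.beach import mobeach_v0  # assumed example environment

    par_env = mobeach_v0.parallel_env()
    aec_env = mo_parallel_to_aec(par_env)
    aec_env.reset(seed=42)
    for agent in aec_env.agent_iter():
        observation, reward, termination, truncation, info = aec_env.last()
        # `reward` is a numpy vector with one entry per objective.
        action = None if termination or truncation else aec_env.action_space(agent).sample()
        aec_env.step(action)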


class mo_aec_to_parallel_wrapper(aec_to_parallel_wrapper, MOParallelEnv):
    """Converts an AEC environment into a Parallel environment.

    Overrides PZ behavior to handle vectorial rewards. Keeping inheritance avoids code duplication and checks for instance type.
    """

    def __init__(self, aec_env):
        """Converts an MO AEC environment into an MO Parallel environment."""
        super().__init__(aec_env)

    @property
    def reward_spaces(self):
        """Returns the reward spaces of the environment."""
        return self.aec_env.reward_spaces

    @override
    def reward_space(self, agent):
        return self.aec_env.reward_spaces[agent]

    @override
    def step(self, actions):
        rewards = defaultdict(lambda: np.zeros(self.reward_space(self.aec_env.agents[0]).shape))
        terminations = {}
        truncations = {}
        infos = {}
        observations = {}
        for agent in self.aec_env.agents:
            if agent != self.aec_env.agent_selection:
                if self.aec_env.terminations[agent] or self.aec_env.truncations[agent]:
                    raise AssertionError(
                        f"expected agent {agent} got termination or truncation agent {self.aec_env.agent_selection}. Parallel environment wrapper expects all agent death (setting an agent's self.terminations or self.truncations entry to True) to happen only at the end of a cycle."
                    )
                else:
                    raise AssertionError(
                        f"expected agent {agent} got agent {self.aec_env.agent_selection}, Parallel environment wrapper expects agents to step in a cycle."
                    )
            obs, rew, termination, truncation, info = self.aec_env.last()
            self.aec_env.step(actions[agent])
            for agent in self.aec_env.agents:
                rewards[agent] += self.aec_env.rewards[agent]

        terminations = dict(**self.aec_env.terminations)
        truncations = dict(**self.aec_env.truncations)
        infos = dict(**self.aec_env.infos)
        observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}

        while self.aec_env.agents and (
            self.aec_env.terminations[self.aec_env.agent_selection]
            or self.aec_env.truncations[self.aec_env.agent_selection]
        ):
            self.aec_env.step(None)

        self.agents = self.aec_env.agents
        return observations, rewards, terminations, truncations, infos
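
# Why the defaultdict factory differs from PettingZoo's scalar version: with scalar
# rewards, `defaultdict(int)` suffices, but vector rewards need a zero *array* of the
# right shape so that `+=` accumulates component-wise. A minimal standalone sketch,
# using an assumed number of objectives purely for illustration:
def _demo_vector_reward_accumulation():
    from collections import defaultdict

    import numpy as np

    num_objectives = 2  # assumed number of objectives for illustration
    rewards = defaultdict(lambda: np.zeros(num_objectives))
    rewards["agent_0"] += np.array([1.0, 0.0])
    rewards["agent_0"] += np.array([0.5, 2.0])
    return rewards["agent_0"]  # array([1.5, 2.0])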
class mo_parallel_to_aec_wrapper(parallel_to_aec_wrapper, MOAECEnv):
    """Converts a parallel environment into an AEC environment.

    Overrides PZ behavior to handle vectorial rewards. Keeping inheritance avoids code duplication and checks for instance type.
    """

    def __init__(self, parallel_env):
        """Converts an MO parallel environment into an MO AEC environment."""
        super().__init__(parallel_env)

    @property
    def reward_spaces(self):
        """Returns the reward spaces of the environment."""
        return self.env.reward_spaces

    @override
    def reward_space(self, agent):
        return self.env.reward_spaces[agent]

    @override
    def reset(self, seed=None, options=None):
        super().reset(seed, options)
        self.rewards = {agent: np.zeros(self.reward_space(agent).shape[0], dtype=np.float32) for agent in self.agents}
        self._cumulative_rewards = {
            agent: np.zeros(self.reward_space(agent).shape[0], dtype=np.float32) for agent in self.agents
        }

    @override
    def add_new_agent(self, new_agent):
        super().add_new_agent(new_agent)
        self.rewards[new_agent] = np.zeros(self.reward_space(new_agent).shape[0], dtype=np.float32)
        self._cumulative_rewards[new_agent] = np.zeros(self.reward_space(new_agent).shape[0], dtype=np.float32)

    @override
    def step(self, action: ActionType):
        if self.terminations[self.agent_selection] or self.truncations[self.agent_selection]:
            del self._actions[self.agent_selection]
            assert action is None
            self._was_dead_step(action)
            return
        self._actions[self.agent_selection] = action
        if self._agent_selector.is_last():
            obss, rews, terminations, truncations, infos = self.env.step(self._actions)

            self._observations = copy.copy(obss)
            self.terminations = copy.copy(terminations)
            self.truncations = copy.copy(truncations)
            self.infos = copy.copy(infos)
            self.rewards = copy.copy(rews)
            self._cumulative_rewards = copy.copy(rews)

            env_agent_set = set(self.env.agents)
            self.agents = self.env.agents + [
                agent for agent in sorted(self._observations.keys()) if agent not in env_agent_set
            ]

            if len(self.env.agents):
                self._agent_selector = agent_selector(self.env.agents)
                self.agent_selection = self._agent_selector.reset()

            self._deads_step_first()
        else:
            if self._agent_selector.is_first():
                self._clear_rewards()

            self.agent_selection = self._agent_selector.next()

    @override
    def _clear_rewards(self) -> None:
        MOAECEnv._clear_rewards(self)
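
# Because both wrappers keep the PettingZoo wrapper classes in their MRO, the isinstance
# checks in mo_aec_to_parallel / mo_parallel_to_aec can unwrap them again, so a conversion
# round trip hands back the original environment rather than stacking wrappers. A sketch
# under the same assumed example environment as above:
def _example_round_trip_unwrapping():
    from momaland.envs.beach import mobeach_v0  # assumed example environment

    original = mobeach_v0.env()
    parallel = mo_aec_to_parallel(original)  # wraps with mo_aec_to_parallel_wrapper
    unwrapped = mo_parallel_to_aec(parallel)  # isinstance check returns parallel.aec_env
    assert unwrapped is original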