"""Various wrappers for Parallel MO environments."""

from collections import namedtuple
from typing import Optional

import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.wrappers.normalize import RunningMeanStd
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper

from momaland.learning.utils import remap_actions
from momaland.utils.env import MOParallelEnv

class RecordEpisodeStatistics(BaseParallelWrapper):
    """This wrapper records per-agent episode statistics (cumulative reward and episode length).

    When the episode ends, the statistics are added to the ``infos`` dict returned by ``step``
    under the ``"episode"`` key, as ``{"r": {agent: return}, "l": {agent: length}}``.
    """

    def __init__(self, env):
        """This wrapper records episode statistics and reports them at the end of each episode.

        Args:
            env (env): The environment to apply the wrapper
        """
        BaseParallelWrapper.__init__(self, env)
        self.episode_rewards = {agent: 0 for agent in self.possible_agents}
        self.episode_lengths = {agent: 0 for agent in self.possible_agents}

    def step(self, actions):
        """Steps through the environment, recording episode statistics.

        Returns:
            The wrapped environment's ``(obs, rews, terminateds, truncateds, infos)``; at the end of an
            episode ``infos["episode"]`` holds copies of the accumulated returns and lengths.
        """
        obs, rews, terminateds, truncateds, infos = super().step(actions)
        # Only agents present in this step's rewards are updated; an agent missing from ``rews``
        # would otherwise raise KeyError.
        for agent, reward in rews.items():
            # Rebind instead of ``+=`` so reward arrays are never mutated in place.
            self.episode_rewards[agent] = self.episode_rewards[agent] + reward
            self.episode_lengths[agent] += 1
        # The episode is over once every agent is either terminated or truncated.
        episode_over = all(terminateds.get(agent, False) or truncateds.get(agent, False) for agent in terminateds)
        if episode_over or all(truncateds.values()):
            # Report copies: reset() zeroes the internal counters, which would otherwise also
            # wipe the statistics already handed to the caller.
            infos["episode"] = {
                "r": dict(self.episode_rewards),
                "l": dict(self.episode_lengths),
            }
        return obs, rews, terminateds, truncateds, infos

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the environment and the episode statistics.

        Args:
            seed: forwarded to the wrapped environment's reset.
            options: forwarded to the wrapped environment's reset.
        """
        obs, info = super().reset(seed, options)
        for agent in self.env.possible_agents:
            self.episode_rewards[agent] = 0
            self.episode_lengths[agent] = 0
        return obs, info
class LinearizeReward(BaseParallelWrapper):
    """Convert MO reward vector into scalar SO reward value.

    `weights` represents the weights of each objective in the reward vector space for each agent.

    Example:
        >>> weights = {"agent_0": np.array([0.1, 0.9]), "agent_1": np.array([0.2, 0.8])}
        >>> env = LinearizeReward(env, weights)
    """

    def __init__(self, env, weights: dict):
        """Reward linearization class initializer.

        Args:
            env: base env to add the wrapper on.
            weights: a dict where keys are agents and values are vectors representing the weights of their rewards.
        """
        self.weights = weights
        super().__init__(env)

    def step(self, actions):
        """Returns a reward scalar from the reward vector.

        Agents that have no entry in ``weights`` keep their original reward unchanged.
        """
        observations, rewards, terminations, truncations, infos = self.env.step(actions)
        # Only existing keys are reassigned, so updating ``rewards`` while iterating it is safe.
        for key in rewards:
            if key not in self.weights:
                continue
            # Weighted sum of the objectives: dot product of the reward vector with the agent's weights.
            rewards[key] = np.dot(rewards[key], self.weights[key])
        return observations, rewards, terminations, truncations, infos
class NormalizeReward(BaseParallelWrapper):
    r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.

    The exponential moving average will have variance :math:`(1 - \gamma)^2`.

    Note:
        The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
        instantiated or the policy was changed recently.

    Example:
        >>> for agent in env.possible_agents:
        ...     for idx in range(env.reward_space(agent).shape[0]):
        ...         env = NormalizeReward(env, agent, idx)
    """

    def __init__(
        self,
        env,
        agent,
        idx,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.

        Args:
            env: The environment to apply the wrapper
            agent: the agent whose reward will be normalized
            idx: the index of the rewards that will be normalized.
            gamma: The discount factor that is used in the exponential moving average.
            epsilon: A stability parameter
        """
        super().__init__(env)
        self.agent = agent
        self.idx = idx
        self.return_rms = RunningMeanStd(shape=())
        self.returns = np.array([0.0])
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, actions):
        """Steps through the environment, normalizing the rewards returned.

        Only objective ``idx`` of ``agent``'s reward is rescaled; every other reward is passed through.
        """
        observations, rewards, terminations, truncations, infos = self.env.step(actions)
        # The tracked agent may be absent from this step's rewards; nothing to normalize then.
        if self.agent not in rewards:
            return observations, rewards, terminations, truncations, infos

        reward = rewards[self.agent]
        is_vector = isinstance(reward, np.ndarray)  # array vs float
        # Extracts the objective value to normalize
        raw = reward[self.idx] if is_vector else reward
        self.returns = self.returns * self.gamma * (1 - terminations.get(self.agent, False)) + raw
        # Defer normalization to gym implementation
        normalized = self.normalize(raw)

        # Inject the normalized objective value back into the reward
        if is_vector:
            # Write into a floating-point copy: assigning into the env's own array would mutate
            # whatever buffer the env returned, and an integer-typed array would truncate the value.
            new_reward = reward.astype(np.result_type(reward.dtype, np.float32))
            new_reward[self.idx] = normalized
            rewards[self.agent] = new_reward
        else:
            rewards[self.agent] = normalized
        return observations, rewards, terminations, truncations, infos

    def normalize(self, rews):
        """Normalizes the rewards with the running mean rewards and their variance."""
        self.return_rms.update(self.returns)
        return rews / np.sqrt(self.return_rms.var + self.epsilon)
class CentraliseAgent(BaseParallelWrapper):
    """This wrapper will create a central agent that observes the full state of the environment.

    The central agent will receive the concatenation of all agents' observations as its own observation (or a global
    state, if available in the environment), and a multi-objective reward vector (representing the component-wise sum,
    or mean, of the individual agent rewards) as its own reward. The central agent is expected to return a vector of
    actions, one for each agent in the original environment.
    """

    def __init__(self, env: MOParallelEnv, action_mapping=False, reward_type="sum"):
        """Central agent wrapper class initializer.

        Args:
            env: The parallel environment to apply the wrapper
            action_mapping: Whether to use an action mapping to Discrete spaces or not
            reward_type: The type of reward grouping to use, either 'sum' or 'mean'

        Raises:
            ValueError: if ``reward_type`` is neither 'sum' nor 'mean', or if ``action_mapping`` or a central
                observation is requested but the first agent's action space has no discrete action count ``n``.
        """
        if reward_type not in ("sum", "mean"):
            raise ValueError(f"reward_type must be either 'sum' or 'mean', got {reward_type!r}")
        super().__init__(env)
        self.action_mapping = action_mapping
        self._reward_type = reward_type

        # Minimal spec object exposing ``id`` (the env name from its metadata).
        spec_cls = namedtuple("Spec", ["id"])
        self.unwrapped.spec = spec_cls(id=self.env.metadata.get("name"))

        central_observation = self.env.metadata.get("central_observation")
        if central_observation:
            self.observation_space = env.get_central_observation_space()
            self.unwrapped.observation_space = env.get_central_observation_space()
        else:
            self.observation_space = Dict({agent_id: env.observation_space(agent_id) for agent_id in self.possible_agents})

        # For compatibility with MORL baselines, the joint action space is derived from the first agent's action space.
        ag0_action_space = env.action_space(self.possible_agents[0])
        self.num_actions = getattr(ag0_action_space, "n", None)
        if self.num_actions is None and (self.action_mapping or central_observation):
            raise ValueError(
                "CentraliseAgent with action_mapping or a central observation requires a Discrete per-agent action space"
            )

        if self.action_mapping:
            # One Discrete index encodes the joint action of all agents.
            self.action_space = Discrete(self.num_actions ** len(self.possible_agents))
            self.unwrapped.action_space = self.action_space
        elif central_observation:
            # A Discrete(n, start) space covers the values start .. start + n - 1.
            self.action_space = Box(
                low=ag0_action_space.start,
                high=ag0_action_space.start + ag0_action_space.n - 1,
                shape=(len(self.possible_agents),),
                dtype=ag0_action_space.dtype,
            )
        else:
            self.action_space = Dict({agent_id: env.action_space(agent_id) for agent_id in self.possible_agents})

        self.reward_space = self.env.reward_space(self.possible_agents[0])
        self.unwrapped.reward_space = self.reward_space

    def step(self, actions):
        """Steps through the environment, joining the returned values for the central agent.

        Returns:
            ``(observations, joint_reward, terminated, truncated, infos)`` where ``joint_reward`` is the sum or mean
            of the agents' rewards and ``terminated``/``truncated`` are true if any agent terminated/truncated.
        """
        # Remake the central action into a dictionary compatible with MOMAland environments.
        if self.action_mapping:
            # The Discrete space was sized with len(self.possible_agents), so decode with the same agent count.
            remapped_actions = remap_actions(actions, len(self.possible_agents), self.num_actions)
            actions = {agent: remapped_actions[i] for i, agent in enumerate(self.possible_agents)}
        elif self.env.metadata.get("central_observation"):
            actions = {agent: actions[num] for num, agent in enumerate(self.possible_agents)}

        observations, rewards, terminations, truncations, infos = self.env.step(actions)
        if self.env.metadata.get("central_observation"):
            observations = self.env.state().flatten()

        if self._reward_type == "sum":
            joint_reward = np.sum(list(rewards.values()), axis=0)
        else:
            joint_reward = np.mean(list(rewards.values()), axis=0)

        return (
            observations,
            joint_reward,
            np.any(list(terminations.values())),
            np.any(list(truncations.values())),
            infos,
        )

    def reset(self, seed=None, options=None):
        """Resets the environment, joining the returned values for the central agent.

        Args:
            seed: forwarded to the wrapped environment's reset.
            options: forwarded to the wrapped environment's reset.

        Returns:
            The observations (the flattened global state when the env provides a central observation) and the
            values of the wrapped environment's ``infos`` dict as a list.
        """
        observations, infos = self.env.reset(seed=seed, options=options)
        if self.env.metadata.get("central_observation"):
            observations = self.env.state().flatten()
        return observations, list(infos.values())