Source code for momaland.utils.parallel_wrappers

"""Various wrappers for Parallel MO environments."""
from collections import namedtuple
from typing import Optional

import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.wrappers.normalize import RunningMeanStd
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper

from momaland.learning.utils import remap_actions
from momaland.utils.env import MOParallelEnv


class RecordEpisodeStatistics(BaseParallelWrapper):
    """Record per-agent episode returns and lengths and report them through ``infos``.

    On the step where every agent is terminated, or every agent is truncated,
    the returned ``infos`` dict gains an ``"episode"`` entry of the form
    ``{"r": {agent: accumulated reward}, "l": {agent: number of steps}}``.
    Nothing is printed; consumers read the statistics from ``infos``.
    """

    def __init__(self, env):
        """Wrap ``env`` and start every per-agent counter at zero.

        Args:
            env (env): The environment to apply the wrapper
        """
        BaseParallelWrapper.__init__(self, env)
        self.episode_rewards = {agent: 0 for agent in self.possible_agents}
        self.episode_lengths = {agent: 0 for agent in self.possible_agents}

    def step(self, actions):
        """Step the environment and accumulate the episode statistics.

        Args:
            actions: the per-agent actions forwarded to the wrapped environment.

        Returns:
            The ``(obs, rews, terminateds, truncateds, infos)`` tuple of the
            wrapped environment; ``infos`` additionally carries the
            ``"episode"`` entry once all agents are terminated or truncated.
        """
        obs, rews, terminateds, truncateds, infos = super().step(actions)
        for agent in self.env.possible_agents:
            if agent not in rews:
                # No reward was reported for this agent on this step, so there
                # is nothing to accumulate for it.
                continue
            # Rebind rather than ``+=``: an in-place add on an accumulated
            # array would also modify any snapshot already handed out below.
            self.episode_rewards[agent] = self.episode_rewards[agent] + rews[agent]
            self.episode_lengths[agent] += 1
        if all(terminateds.values()) or all(truncateds.values()):
            # Hand out copies: reset() rewrites the live dicts, which would
            # otherwise zero the statistics the caller is still holding.
            infos["episode"] = {
                "r": dict(self.episode_rewards),
                "l": dict(self.episode_lengths),
            }
        return obs, rews, terminateds, truncateds, infos

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the environment and zero the episode statistics.

        Args:
            seed: seed forwarded to the wrapped environment.
            options: options forwarded to the wrapped environment.

        Returns:
            The ``(obs, info)`` pair returned by the wrapped environment.
        """
        obs, info = super().reset(seed=seed, options=options)
        for agent in self.env.possible_agents:
            self.episode_rewards[agent] = 0
            self.episode_lengths[agent] = 0
        return obs, info
class LinearizeReward(BaseParallelWrapper):
    """Convert MO reward vector into scalar SO reward value.

    `weights` represents the weights of each objective in the reward vector space for each agent.
    Agents that have no entry in `weights` keep their reward unchanged.

    Example:
        >>> weights = {"agent_0": np.array([0.1, 0.9]), "agent_1": np.array([0.2, 0.8])}
        >>> env = LinearizeReward(env, weights)
    """

    def __init__(self, env, weights: dict):
        """Reward linearization class initializer.

        Args:
            env: base env to add the wrapper on.
            weights: a dict where keys are agents and values are vectors representing the weights of their rewards.
        """
        self.weights = weights
        super().__init__(env)

    def step(self, actions):
        """Step the environment and scalarize the rewards of the weighted agents.

        Args:
            actions: the per-agent actions forwarded to the wrapped environment.

        Returns:
            The step tuple of the wrapped environment, where the reward of
            every agent present in ``self.weights`` is replaced by the dot
            product of its reward vector with its weight vector.
        """
        observations, rewards, terminations, truncations, infos = self.env.step(actions)
        weights = self.weights
        # Build a fresh dict so the reward dict owned by the wrapped
        # environment is left untouched; membership is tested on the mapping
        # itself instead of rebuilding a list of its keys for every agent.
        scalarized = {
            agent: np.dot(reward, weights[agent]) if agent in weights else reward
            for agent, reward in rewards.items()
        }
        return observations, scalarized, terminations, truncations, infos
class NormalizeReward(BaseParallelWrapper):
    r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.

    The exponential moving average will have variance :math:`(1 - \gamma)^2`.

    Note:
        The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
        instantiated or the policy was changed recently.

    Example:
        >>> for agent in env.possible_agents:
        ...     for idx in range(env.reward_space(agent).shape[0]):
        ...         env = NormalizeReward(env, agent, idx)
    """

    def __init__(
        self,
        env,
        agent,
        idx,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.

        Args:
            env: The environment to apply the wrapper
            agent: the agent whose reward will be normalized
            idx: the index of the rewards that will be normalized.
            epsilon: A stability parameter
            gamma: The discount factor that is used in the exponential moving average.
        """
        super().__init__(env)
        self.agent = agent
        self.idx = idx
        self.return_rms = RunningMeanStd(shape=())
        self.returns = np.array([0.0])
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, actions):
        """Step the environment and normalize one objective of one agent's reward.

        Args:
            actions: the per-agent actions forwarded to the wrapped environment.

        Returns:
            The step tuple of the wrapped environment. If ``self.agent`` has a
            reward on this step, component ``self.idx`` of it (or the whole
            reward when it is not an ``np.ndarray``) is replaced by its
            normalized value; otherwise the tuple is returned unchanged.
        """
        observations, rewards, terminations, truncations, infos = self.env.step(actions)

        if self.agent not in rewards:
            # No reward was reported for the tracked agent on this step:
            # nothing to normalize, and the running return is left as it is.
            return observations, rewards, terminations, truncations, infos

        reward = rewards[self.agent]
        is_vector = isinstance(reward, np.ndarray)  # array vs float

        # Extracts the objective value to normalize
        to_normalize = reward[self.idx] if is_vector else reward

        # The discounted return restarts once the agent has terminated.
        terminated = terminations.get(self.agent, False)
        self.returns = self.returns * self.gamma * (1 - terminated) + to_normalize

        # Defer normalization to gym implementation
        to_normalize = self.normalize(to_normalize)

        # Injecting the normalized objective value back into the reward vector
        if is_vector:
            # Write into a copy so the array owned by the wrapped environment
            # is not modified; a non-float vector is cast to float first so
            # the normalized value is not truncated on assignment.
            updated = reward.copy() if reward.dtype.kind == "f" else reward.astype(np.float64)
            updated[self.idx] = to_normalize
            rewards[self.agent] = updated
        else:
            rewards[self.agent] = to_normalize

        return observations, rewards, terminations, truncations, infos

    def normalize(self, rews):
        """Update the running return statistics and scale ``rews`` by the return standard deviation.

        Args:
            rews: the reward value to scale.

        Returns:
            ``rews`` divided by ``sqrt(running return variance + epsilon)``.
        """
        self.return_rms.update(self.returns)
        return rews / np.sqrt(self.return_rms.var + self.epsilon)
class CentraliseAgent(BaseParallelWrapper):
    """This wrapper exposes a multi-agent parallel environment as a single central agent.

    If the environment metadata sets ``central_observation``, the central agent observes the flattened global
    state (``env.state().flatten()``); otherwise it receives the per-agent observation dict unchanged.
    Its reward is the component-wise sum (or mean, depending on ``reward_type``) of the individual agent
    rewards, and an episode is reported as terminated/truncated as soon as any agent is.
    The central agent is expected to provide one action per agent of the original environment.
    """

    def __init__(self, env: MOParallelEnv, action_mapping=False, reward_type="sum"):
        """Central agent wrapper class initializer.

        Args:
            env: The parallel environment to apply the wrapper
            action_mapping: Whether to use an action mapping to Discrete spaces or not
            reward_type: The type of reward grouping to use, either 'sum' or 'mean'
                (any value other than 'sum' is treated as 'mean')
        """
        super().__init__(env)
        self.action_mapping = action_mapping
        # Minimal stand-in exposing ``spec.id``; the id is stored on the
        # namedtuple class itself rather than on an instance.
        self.unwrapped.spec = namedtuple("Spec", ["id"])
        self.unwrapped.spec.id = self.env.metadata.get("name")
        self._reward_type = reward_type

        if self.env.metadata.get("central_observation"):
            self.observation_space = env.get_central_observation_space()
            self.unwrapped.observation_space = env.get_central_observation_space()
        else:
            self.observation_space = Dict({agentID: env.observation_space(agentID) for agentID in self.possible_agents})

        # For compatibility with MORL baselines, the joint action space is
        # derived from the first agent's action space.
        # NOTE(review): reading ``.n`` assumes that space is Discrete and that
        # all agents share it - confirm against the supported environments.
        ag0_action_space = env.action_space(self.possible_agents[0])
        self.num_actions = ag0_action_space.n
        if self.action_mapping:
            # One joint Discrete action indexes every combination of
            # individual actions.
            self.action_space = Discrete(self.num_actions ** len(self.possible_agents))
            self.unwrapped.action_space = self.action_space
        elif self.env.metadata.get("central_observation"):
            # One entry per agent. A Discrete space covers
            # [start, start + n - 1], so the upper bound must include the
            # offset (identical to ``n - 1`` when ``start`` is 0).
            self.action_space = Box(
                low=ag0_action_space.start,
                high=(ag0_action_space.start + ag0_action_space.n - 1),
                shape=(len(self.possible_agents),),
                dtype=ag0_action_space.dtype,
            )
        else:
            self.action_space = Dict({agentID: env.action_space(agentID) for agentID in self.possible_agents})

        self.reward_space = self.env.reward_space(self.possible_agents[0])
        self.unwrapped.reward_space = self.reward_space

    def step(self, actions):
        """Steps through the environment, joining the returned values for the central agent.

        Args:
            actions: a joint action index when ``action_mapping`` is set, a sequence with one action per agent
                when the environment provides a central observation, otherwise a per-agent action dict.

        Returns:
            ``(observations, joint_reward, terminated, truncated, infos)`` where ``joint_reward`` is the
            sum or mean of the agents' rewards along axis 0 and the two flags are true if any agent's flag is.
        """
        # Remake the action list into a dictionary compatible with MOMAland environments
        if self.action_mapping:
            remapped_actions = remap_actions(actions, len(self.agents), self.num_actions)
            actions = {agent: remapped_actions[i] for i, agent in enumerate(self.agents)}
        elif self.env.metadata.get("central_observation"):
            actions = {agent: actions[num] for num, agent in enumerate(self.possible_agents)}

        observations, rewards, terminations, truncations, infos = self.env.step(actions)
        if self.env.metadata.get("central_observation"):
            observations = self.env.state().flatten()

        reward_vectors = list(rewards.values())
        if self._reward_type == "sum":
            joint_reward = np.sum(reward_vectors, axis=0)
        else:
            joint_reward = np.mean(reward_vectors, axis=0)

        return (
            observations,
            joint_reward,
            np.any(list(terminations.values())),
            np.any(list(truncations.values())),
            infos,
        )

    def reset(self, seed=None, options=None):
        """Resets the environment, joining the returned values for the central agent.

        Args:
            seed: seed forwarded to the wrapped environment.
            options: options forwarded to the wrapped environment.

        Returns:
            ``(observations, infos)`` where ``infos`` is the list of the per-agent info values.
        """
        observations, infos = self.env.reset(seed=seed, options=options)
        if self.env.metadata.get("central_observation"):
            observations = self.env.state().flatten()
        # NOTE(review): the infos are returned as a list, not a dict; kept
        # as-is because callers may rely on it.
        return observations, list(infos.values())