@simmax21 I had to make a custom environment, with the help of @aakarshanc01.
If you can improve this code further, that would be amazing for me and for the other people who come after us:
def get_wandb_ue_env():
    """Build and return a ``UE`` environment with two side channels attached.

    Reads the module-level ``config`` (``time_scale``, ``env_path``,
    ``no_graphics``) and the module-level ``rank``, which is used both as the
    worker id and as the offset added to the base port 5000.

    :return: the newly constructed ``UE`` instance.
    """
    # Engine configuration: only the time scale is set here.
    time_scale_channel = EngineConfigurationChannel()
    time_scale_channel.set_configuration_parameters(time_scale=config.time_scale)

    # Stats side channel.
    stats_channel = SB3StatsRecorder()

    # Environment itself; the seed is fixed to 1.
    return UE(
        config.env_path,
        seed=1,
        worker_id=rank,
        base_port=5000 + rank,
        no_graphics=config.no_graphics,
        side_channels=[time_scale_channel, stats_channel],
    )
class CustomEnv(gym.Env):
    """Gym environment that re-exposes the wrapped env's observations as a dict.

    The wrapped environment is created with ``get_wandb_ue_env()`` and passed
    through ``UnityToGymWrapper(..., allow_multiple_obs=True)``. Its
    observations are indexed as ``s[0]``, ``s[1]``, ``s[2]`` and returned as a
    dict keyed ``0``, ``1``, ``2`` to match ``observation_space``.
    """

    def __init__(self):
        super().__init__()
        self.env = UnityToGymWrapper(get_wandb_ue_env(), allow_multiple_obs=True)
        self.action_space = self.env.action_space
        self.action_size = self.env.action_size
        # Shapes are hard-coded; presumably they must match what the
        # environment binary emits -- verify when changing the scene.
        self.observation_space = gym.spaces.Dict({
            # other shape noted by the author: (40, 90, 3)
            0: gym.spaces.Box(low=0, high=1, shape=(27, 60, 3)),
            # other shape noted by the author: (56, 121, 1)
            1: gym.spaces.Box(low=0, high=1, shape=(20, 40, 1)),
            # Unbounded vector: use real float infinities, not the strings
            # '-inf'/'inf', so the bounds are numeric regardless of how the
            # installed gym/numpy versions coerce their arguments.
            2: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(400,)),
        })

    @staticmethod
    def tuple_to_dict(s):
        """Map the first three entries of ``s`` to a dict keyed 0, 1, 2.

        :param s: indexable observation container with at least three items.
        :return: ``{0: s[0], 1: s[1], 2: s[2]}``.
        :raises IndexError: if ``s`` is a sequence with fewer than three items.
        """
        return {0: s[0], 1: s[1], 2: s[2]}

    def reset(self):
        """Reset the wrapped environment and return its observation as a dict."""
        return self.tuple_to_dict(self.env.reset())

    def step(self, action):
        """Step the wrapped environment.

        :param action: action forwarded unchanged to the wrapped environment.
        :return: ``(obs_dict, reward, done, info)`` where the reward is cast to
            a Python ``float``.
        """
        s, r, d, info = self.env.step(action)
        return self.tuple_to_dict(s), float(r), d, info

    def close(self):
        """Close the wrapped environment, then decrement the module-level ``rank``."""
        global rank
        self.env.close()
        rank -= 1

    def render(self, mode="human"):
        """Delegate to the wrapped environment's ``render`` and return its result.

        ``mode`` is accepted for Gym API compatibility but is not forwarded.
        """
        return self.env.render()
class SB3StatsRecorder(SideChannel):
    """
    Side channel that receives (string, float) pairs from the environment, so that they can eventually
    be passed to a StatsReporter.

    Each received stat is also forwarded to the module-level ``env_callback``
    when that callback is set and ``wandb_run_identifier`` equals ``"test"``.
    """

    def __init__(self) -> None:
        # >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/StatsSideChannel")
        # UUID('a1d8f7b7-cec8-50f9-b78b-d3e165a78520')
        super().__init__(uuid.UUID("a1d8f7b7-cec8-50f9-b78b-d3e165a78520"))
        pretty_print("Initializing SB3StatsRecorder", Colors.FAIL)
        # key -> list of (value, aggregation method) tuples since the last reset
        self.stats: EnvironmentStats = defaultdict(list)
        # number of messages received so far
        self.i = 0
        self.wandb_tables: dict = {}

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Receive the message from the environment, and save it for later retrieval.

        The message is read as a string key, a float32 value and an int32
        aggregation-method code, in that order.

        :param msg: incoming side-channel message.
        :return:
        """
        key = msg.read_string()
        val = msg.read_float32()
        agg_type = StatsAggregationMethod(msg.read_int32())
        self.stats[key].append((val, agg_type))
        self.i += 1

        if env_callback is None or wandb_run_identifier != "test":
            return

        # assign different Drone[id] to each subprocess within this wandb run
        # The callback gets the second "/"-separated segment of the key. Keys
        # without a "/" are forwarded whole instead of raising IndexError.
        segments = key.split("/")
        short_key = segments[1] if len(segments) > 1 else key
        my_table_id: str = "Performance[{}]".format(wandb_run_identifier)
        env_callback(my_table_id, short_key, val)

    def get_and_reset_stats(self) -> EnvironmentStats:
        """
        Returns the current stats, and resets the internal storage of the stats.

        :return: the stats accumulated since the previous call.
        """
        collected = self.stats
        self.stats = defaultdict(list)
        return collected