Log per-step intrinsic curiosity reward (ICM) during training

Hi. Is there a way to log the intrinsic curiosity reward (either ICM or RND) during training? In particular, I’m looking to log the per-step reward rather than the cumulative reward over an episode. Is there, for example, a way to overload the internal code for the curiosity reward? Thanks!

Here’s my solution, which in this form only works for a single agent/training process. It requires modifying two files in the ML-Agents Python code:

  1. trainers/trainer/rl_trainer.py:
class RLTrainer(Trainer):
    """
    This class is the base class for trainers that use Reward Signals.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
        # used for reporting only. We always want to report the environment reward to Tensorboard, regardless
        # of what reward signals are actually present.
        self.cumulative_returns_since_policy_update: List[float] = []

        # ADDITION BEGIN
        # defaultdict is already imported at the top of this file
        self.intrinsic_rewards: Dict[str, List[List[float]]] = defaultdict(list)
        # ADDITION END
    
        self.collected_rewards: Dict[str, Dict[str, int]] = {
            "environment": defaultdict(lambda: 0)
        }
        self.update_buffer: AgentBuffer = AgentBuffer()
        self._stats_reporter.add_property(
            StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
        )
  2. trainers/ppo/trainer.py (substitute the trainer for your specific learner):
# Evaluate all reward functions
        self.collected_rewards["environment"][agent_id] += np.sum(
            agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS]
        )
        for name, reward_signal in self.optimizer.reward_signals.items():
            # split the original one-liner so the raw (unscaled) rewards are available
            raw_rewards = reward_signal.evaluate(agent_buffer_trajectory)
            evaluate_result = raw_rewards * reward_signal.strength
            agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].extend(
                evaluate_result
            )
            # Report the reward signals
            self.collected_rewards[name][agent_id] += np.sum(evaluate_result)

            # ADDITION BEGIN
            if name in ("curiosity", "rnd"):
                self.intrinsic_rewards[name].append(raw_rewards.tolist())
            # ADDITION END
        # If this was a terminal trajectory, append stats and reset reward collection
        if trajectory.done_reached:
            self._update_end_episode_stats(agent_id, self.optimizer)

            # ADDITION BEGIN
            # requires `from itertools import chain` at the top of the file
            for reward_type, rewards in self.intrinsic_rewards.items():
                reward_list = list(chain.from_iterable(rewards))
                logger.info(
                    "%s (n=%d, episode done=%s, max steps reached=%s): %s"
                    % (
                        reward_type,
                        len(reward_list),
                        trajectory.done_reached,
                        trajectory.interrupted,
                        reward_list,
                    )
                )
                self.intrinsic_rewards[reward_type] = []  # reset list
            # ADDITION END
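Outside of ML-Agents, the accumulate/flatten/reset pattern used in the additions can be sketched in isolation; the reward values below are made up, only the data flow matches the modification:

```python
from collections import defaultdict
from itertools import chain

# One list of per-step rewards is appended per processed trajectory chunk.
intrinsic_rewards = defaultdict(list)

# During training, each trajectory contributes its per-step raw rewards:
intrinsic_rewards["curiosity"].append([0.66, 0.32, 0.34])  # trajectory chunk 1
intrinsic_rewards["curiosity"].append([0.37, 0.33])        # trajectory chunk 2

# At episode end, the chunks are flattened into one per-step list and reset:
for reward_type, rewards in intrinsic_rewards.items():
    reward_list = list(chain.from_iterable(rewards))
    print("%s (n=%d): %s" % (reward_type, len(reward_list), reward_list))
    intrinsic_rewards[reward_type] = []  # reset for the next episode
```

Flattening with chain.from_iterable is what turns the per-trajectory chunks into the single per-step list that appears in the log output below.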

Output produced on the Hummingbird tutorial with the curiosity reward added to the config:

[INFO] curiosity (n=200, episode done=True, max steps reached=True): [0.6599750518798828, 0.31959402561187744, 0.3357643783092499, 0.36537182331085205, 0.32756492495536804, 0.44029995799064636, 0.27033835649490356, 0.4827609658241272, 0.26806002855300903, 0.4160946011543274, 0.21634532511234283, 0.6046504378318787, 0.43024319410324097, 0.28806763887405396, 0.3422259986400604, 0.3036242723464966, 0.39933544397354126, 0.3542373478412628, 0.447165846824646, 0.2515491247177124, 0.5521566271781921, 0.5290231108665466, 0.5080984830856323, 0.421864777803421, 0.45356306433677673, 0.6989604830741882, 0.43684709072113037, 0.46271616220474243, 0.457720011472702, 0.7005196809768677, 0.4537859857082367, 0.7930405139923096, 0.918682336807251, 0.7083240151405334, 0.9655430912971497, 0.7599287033081055, 0.7403987646102905, 0.7648483514785767, 0.7912213206291199, 0.652010440826416, 0.6089360117912292, 0.4702361524105072, 0.44323816895484924, 0.3977070450782776, 0.7281531095504761, 0.6235811710357666, 0.6788456439971924, 0.4875066578388214, 0.3023841977119446, 0.258270263671875, 0.2918223738670349, 0.39061057567596436, 0.3010035455226898, 0.43863964080810547, 0.346474826335907, 0.3749925196170807, 0.40312352776527405, 0.501847505569458, 0.29245325922966003, 0.32814401388168335, 0.45606595277786255, 0.38132065534591675, 0.29891306161880493, 0.3110980987548828, 0.3008301258087158, 0.2568182349205017, 0.26516398787498474, 0.2860548496246338, 0.35372039675712585, 0.5603305697441101, 0.36819422245025635, 0.42241916060447693, 0.3552458584308624, 0.5897090435028076, 0.2433948516845703, 0.44768980145454407, 0.2919379472732544, 0.24968811869621277, 0.27533912658691406, 0.23073825240135193, 0.36637863516807556, 0.2696627378463745, 0.3565974235534668, 0.48205071687698364, 0.29172483086586, 0.29013964533805847, 0.32133445143699646, 0.29968053102493286, 0.2766324579715729, 0.2395229935646057, 0.22934043407440186, 0.2523424029350281, 0.24817709624767303, 0.28184834122657776, 0.41276824474334717, 
0.33434757590293884, 0.3144347667694092, 0.3479783236980438, 0.2727797031402588, 0.2406274676322937, 0.22619575262069702, 0.32575246691703796, 0.28658246994018555, 0.24428284168243408, 0.21518100798130035, 0.20254336297512054, 0.382632315158844, 0.5246217250823975, 0.3474041223526001, 0.24786989390850067, 0.31459471583366394, 0.40988707542419434, 0.2836333215236664, 0.35507604479789734, 0.3618931174278259, 0.34678199887275696, 0.308364599943161, 0.5079637765884399, 0.45459601283073425, 0.42369797825813293, 0.3319687843322754, 0.4082048535346985, 0.3772601783275604, 0.4668225646018982, 0.37924930453300476, 0.3775586783885956, 0.392621785402298, 0.4379982352256775, 0.5267438888549805, 0.4402915835380554, 0.4267708361148834, 0.40774476528167725, 0.42282211780548096, 0.4781952202320099, 0.5348532199859619, 0.5892888307571411, 0.40422841906547546, 0.4558173716068268, 0.36898142099380493, 0.9205887317657471, 0.6185564398765564, 0.5108629465103149, 0.7636951804161072, 0.4768432378768921, 0.8379564881324768, 0.3983945846557617, 0.4244348108768463, 0.377480149269104, 0.42376893758773804, 0.5243623852729797, 0.4117155075073242, 0.3847920894622803, 0.4880342185497284, 0.5619639754295349, 0.4753367304801941, 0.7594866156578064, 0.6272976994514465, 0.6687712073326111, 0.6804089546203613, 0.7493405342102051, 0.7857539653778076, 0.6962857246398926, 0.8121911287307739, 0.6477372646331787, 0.7025402784347534, 0.7654938697814941, 0.6455159783363342, 1.0511361360549927, 0.847347617149353, 0.609919011592865, 0.7612303495407104, 0.8660853505134583, 0.6686427593231201, 0.8942302465438843, 0.714823842048645, 0.8234495520591736, 1.2917590141296387, 0.9863629937171936, 0.8917754888534546, 0.802504301071167, 0.8014320731163025, 0.8109920024871826, 0.9435753226280212, 0.6368959546089172, 0.6719377636909485, 0.6416013836860657, 0.6162813305854797, 0.5613307952880859, 0.6572178602218628, 0.5976423621177673, 0.9072781801223755, 0.7019497752189636, 0.8401123881340027, 0.6700666546821594, 
0.6234099864959717, 0.6980035901069641, 0.6150760054588318, 0.6935156583786011, 0.7264655828475952, 0.8167094588279724]
[INFO] curiosity (n=200, episode done=True, max steps reached=True): [0.20040760934352875, 0.405997633934021, 0.2014654278755188, 0.5143175721168518, 0.24975121021270752, 0.2219068706035614, 0.28917527198791504, 0.3457401990890503, 0.1857166886329651, 0.24595946073532104, 0.2544567584991455, 0.2575584053993225, 0.6467382907867432, 0.21939019858837128, 0.24156446754932404, 0.263954222202301, 0.456066757440567, 0.42061516642570496, 0.39988625049591064, 0.3835148811340332, 0.575833261013031, 0.529363214969635, 0.5236270427703857, 0.5616291165351868, 0.5749561786651611, 0.33099061250686646, 0.3211314082145691, 0.31990307569503784, 0.8756985068321228, 0.9123344421386719, 0.855627179145813, 0.8519073128700256, 0.8876305818557739, 1.366782546043396, 0.9131333827972412, 1.2058366537094116, 1.2662551403045654, 0.980659544467926, 1.1457749605178833, 0.8143234848976135, 0.7707265019416809, 0.6527854204177856, 0.4105456471443176, 0.7295986413955688, 0.8690494894981384, 0.6910908222198486, 0.641333818435669, 0.5049877762794495, 0.4368477761745453, 0.35646283626556396, 0.31686198711395264, 0.3444564938545227, 0.35278260707855225, 0.3804481327533722, 0.4642121493816376, 0.28114044666290283, 0.363090455532074, 0.2691594660282135, 0.31924164295196533, 0.28078794479370117, 0.359183132648468, 0.31364741921424866, 0.40449196100234985, 0.3378605246543884, 0.3605378270149231, 0.2715175449848175, 0.2720174789428711, 0.49601832032203674, 0.48536407947540283, 0.4551427364349365, 0.4952602982521057, 0.6198722124099731, 0.4677537679672241, 0.5097872614860535, 0.6203373074531555, 0.4208877980709076, 0.2833283841609955, 0.5943808555603027, 0.6143324971199036, 0.6355754733085632, 0.44470831751823425, 0.546419620513916, 0.5391665101051331, 0.5582544803619385, 0.5428563952445984, 0.6655526757240295, 0.562159538269043, 0.7829023599624634, 0.7467460632324219, 0.7492837309837341, 0.7788183093070984, 0.7620762586593628, 0.7492246627807617, 0.6904751658439636, 0.49599489569664, 0.5574407577514648, 
0.6015672087669373, 0.3472583591938019, 0.3523492217063904, 0.4898557662963867, 0.49008071422576904, 0.543617308139801, 0.7521457076072693, 0.5986086130142212, 0.4664466977119446, 0.5208238363265991, 0.31878405809402466, 0.2425195872783661, 0.34365949034690857, 0.3240930736064911, 0.3329620063304901, 0.2607978284358978, 0.31962162256240845, 0.34391117095947266, 0.44687652587890625, 0.32213953137397766, 0.38312703371047974, 0.3649355173110962, 0.34303027391433716, 0.30650216341018677, 0.29493123292922974, 0.4825912117958069, 0.22040557861328125, 0.26829198002815247, 0.3844771683216095, 0.296333372592926, 0.241156205534935, 0.22011947631835938, 0.22701315581798553, 0.27934345602989197, 0.46820250153541565, 0.6166121959686279, 0.742115318775177, 0.7959161400794983, 0.5395683646202087, 0.5242422223091125, 0.5393686890602112, 0.8941488265991211, 0.43604591488838196, 0.28425338864326477, 0.3392718732357025, 0.23488372564315796, 0.3202020525932312, 0.27430957555770874, 0.24479073286056519, 0.2918567657470703, 0.603556215763092, 0.28601017594337463, 0.4155757427215576, 0.5102577209472656, 0.5628896355628967, 0.6641409397125244, 0.5935871601104736, 0.4754033088684082, 0.48434978723526, 0.3132435083389282, 0.38624298572540283, 0.3484439253807068, 0.5724354982376099, 0.8504922389984131, 0.6232136487960815, 0.633483350276947, 0.36520203948020935, 0.20835649967193604, 0.28730839490890503, 0.21950678527355194, 0.2056119292974472, 0.2782301604747772, 0.1878279596567154, 0.4237385392189026, 0.30017736554145813, 0.3848726451396942, 0.30528053641319275, 0.19472374022006989, 0.24799753725528717, 0.28432056307792664, 0.21214967966079712, 0.3687637448310852, 0.2690379321575165, 0.2466214895248413, 0.26589280366897583, 0.4564402997493744, 0.24240784347057343, 0.2574172616004944, 0.2114577293395996, 0.25341400504112244, 0.28775399923324585, 0.24803633987903595, 0.34510499238967896, 0.26538723707199097, 0.27149590849876404, 0.34061628580093384, 0.34328368306159973, 0.21388986706733704, 
0.4521488845348358, 0.30640265345573425, 0.27976229786872864, 0.29255369305610657, 0.21736940741539001, 0.2564901113510132]
...

Note that I also log “episode done” and “max steps reached” to help identify successful episodes. If the first is False and the second is True, the episode ended because the step limit was reached rather than because the goal state was reached, i.e. the run was not successful.
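If you want to analyze the rewards afterwards, the lists can be parsed back out of a saved training log. This is a hypothetical post-processing sketch, assuming lines in the exact format produced by the logger.info call above:

```python
import ast
import re

# Matches e.g.:
# "curiosity (n=3, episode done=True, max steps reached=False): [0.1, 0.2, 0.3]"
# (search() ignores any "[INFO] " prefix before the signal name)
LINE_RE = re.compile(
    r"(?P<name>\w+) \(n=(?P<n>\d+), episode done=(?P<done>\w+), "
    r"max steps reached=(?P<interrupted>\w+)\): (?P<rewards>\[.*\])"
)

def parse_reward_line(line: str):
    """Return (signal name, done, interrupted, per-step rewards), or None."""
    match = LINE_RE.search(line)
    if match is None:
        return None
    # The reward list is a plain Python literal, so literal_eval is safe here.
    rewards = ast.literal_eval(match.group("rewards"))
    return (
        match.group("name"),
        match.group("done") == "True",
        match.group("interrupted") == "True",
        rewards,
    )

sample = "curiosity (n=3, episode done=True, max steps reached=False): [0.1, 0.2, 0.3]"
name, done, interrupted, rewards = parse_reward_line(sample)
print(name, done, interrupted, sum(rewards) / len(rewards))
```

From there the per-step rewards can be fed into whatever plotting or statistics tooling you prefer.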