Hi. Is there a way to log the intrinsic curiosity reward (either ICM or RND) during training? In particular, I'd like to log the per-step reward rather than the cumulative reward over an episode. Is there, for example, a way to override the internal code that computes the curiosity reward? Thanks!
Here’s my solution, which in this form only works for a single agent/training process. It requires modifying two files in the mlagents Python-side code:
- trainers/trainer/rl_trainer.py:

class RLTrainer(Trainer):
    """
    This class is the base class for trainers that use Reward Signals.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # collected_rewards is a dictionary from name of reward signal to a dictionary
        # of agent_id to cumulative reward, used for reporting only. We always want to
        # report the environment reward to Tensorboard, regardless of what reward
        # signals are actually present.
        self.cumulative_returns_since_policy_update: List[float] = []
        # ADDITION BEGIN
        # defaultdict is already imported at the top of rl_trainer.py;
        # maps reward signal name -> list of per-trajectory reward lists
        self.intrinsic_rewards: Dict[str, List[List[float]]] = defaultdict(list)
        # ADDITION END
        self.collected_rewards: Dict[str, Dict[str, int]] = {
            "environment": defaultdict(lambda: 0)
        }
        self.update_buffer: AgentBuffer = AgentBuffer()
        self._stats_reporter.add_property(
            StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
        )
- trainers/ppo/trainer.py (substitute the trainer file for your specific algorithm, e.g. SAC):

# Evaluate all reward functions
self.collected_rewards["environment"][agent_id] += np.sum(
    agent_buffer_trajectory[BufferKey.ENVIRONMENT_REWARDS]
)
for name, reward_signal in self.optimizer.reward_signals.items():
    raw_rewards = reward_signal.evaluate(agent_buffer_trajectory)
    evaluate_result = raw_rewards * reward_signal.strength
    agent_buffer_trajectory[RewardSignalUtil.rewards_key(name)].extend(
        evaluate_result
    )
    # Report the reward signals
    self.collected_rewards[name][agent_id] += np.sum(evaluate_result)
    # ADDITION BEGIN
    if name in ("curiosity", "rnd"):
        self.intrinsic_rewards[name].append(raw_rewards.tolist())
    # ADDITION END

# If this was a terminal trajectory, append stats and reset reward collection
if trajectory.done_reached:
    self._update_end_episode_stats(agent_id, self.optimizer)
    # ADDITION BEGIN
    # requires "from itertools import chain" at the top of the file
    for reward_type, rewards in self.intrinsic_rewards.items():
        reward_list = list(chain.from_iterable(rewards))
        logger.info(
            "%s (n=%d, episode done=%s, max steps reached=%s): %s",
            reward_type,
            len(reward_list),
            trajectory.done_reached,
            trajectory.interrupted,
            reward_list,
        )
        self.intrinsic_rewards[reward_type] = []  # reset list
    # ADDITION END
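The flatten-and-reset step at episode end can be checked outside ML-Agents. A self-contained sketch of the same itertools.chain pattern (the numbers are placeholders):

```python
from collections import defaultdict
from itertools import chain

intrinsic_rewards = defaultdict(list)
# two trajectories' worth of per-step curiosity rewards
intrinsic_rewards["curiosity"].append([0.1, 0.2, 0.3])
intrinsic_rewards["curiosity"].append([0.4, 0.5])

for reward_type, rewards in intrinsic_rewards.items():
    # flatten the list of per-trajectory lists into one per-step list
    reward_list = list(chain.from_iterable(rewards))
    print("%s (n=%d): %s" % (reward_type, len(reward_list), reward_list))
    intrinsic_rewards[reward_type] = []  # reset for the next episode

print(intrinsic_rewards["curiosity"])  # []
```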
Output produced on the Hummingbird tutorial with the curiosity reward added to the config:
[INFO] curiosity (n=200, episode done=True, max steps reached=True): [0.6599750518798828, 0.31959402561187744, 0.3357643783092499, 0.36537182331085205, 0.32756492495536804, 0.44029995799064636, 0.27033835649490356, 0.4827609658241272, 0.26806002855300903, 0.4160946011543274, 0.21634532511234283, 0.6046504378318787, 0.43024319410324097, 0.28806763887405396, 0.3422259986400604, 0.3036242723464966, 0.39933544397354126, 0.3542373478412628, 0.447165846824646, 0.2515491247177124, 0.5521566271781921, 0.5290231108665466, 0.5080984830856323, 0.421864777803421, 0.45356306433677673, 0.6989604830741882, 0.43684709072113037, 0.46271616220474243, 0.457720011472702, 0.7005196809768677, 0.4537859857082367, 0.7930405139923096, 0.918682336807251, 0.7083240151405334, 0.9655430912971497, 0.7599287033081055, 0.7403987646102905, 0.7648483514785767, 0.7912213206291199, 0.652010440826416, 0.6089360117912292, 0.4702361524105072, 0.44323816895484924, 0.3977070450782776, 0.7281531095504761, 0.6235811710357666, 0.6788456439971924, 0.4875066578388214, 0.3023841977119446, 0.258270263671875, 0.2918223738670349, 0.39061057567596436, 0.3010035455226898, 0.43863964080810547, 0.346474826335907, 0.3749925196170807, 0.40312352776527405, 0.501847505569458, 0.29245325922966003, 0.32814401388168335, 0.45606595277786255, 0.38132065534591675, 0.29891306161880493, 0.3110980987548828, 0.3008301258087158, 0.2568182349205017, 0.26516398787498474, 0.2860548496246338, 0.35372039675712585, 0.5603305697441101, 0.36819422245025635, 0.42241916060447693, 0.3552458584308624, 0.5897090435028076, 0.2433948516845703, 0.44768980145454407, 0.2919379472732544, 0.24968811869621277, 0.27533912658691406, 0.23073825240135193, 0.36637863516807556, 0.2696627378463745, 0.3565974235534668, 0.48205071687698364, 0.29172483086586, 0.29013964533805847, 0.32133445143699646, 0.29968053102493286, 0.2766324579715729, 0.2395229935646057, 0.22934043407440186, 0.2523424029350281, 0.24817709624767303, 0.28184834122657776, 0.41276824474334717, 
0.33434757590293884, 0.3144347667694092, 0.3479783236980438, 0.2727797031402588, 0.2406274676322937, 0.22619575262069702, 0.32575246691703796, 0.28658246994018555, 0.24428284168243408, 0.21518100798130035, 0.20254336297512054, 0.382632315158844, 0.5246217250823975, 0.3474041223526001, 0.24786989390850067, 0.31459471583366394, 0.40988707542419434, 0.2836333215236664, 0.35507604479789734, 0.3618931174278259, 0.34678199887275696, 0.308364599943161, 0.5079637765884399, 0.45459601283073425, 0.42369797825813293, 0.3319687843322754, 0.4082048535346985, 0.3772601783275604, 0.4668225646018982, 0.37924930453300476, 0.3775586783885956, 0.392621785402298, 0.4379982352256775, 0.5267438888549805, 0.4402915835380554, 0.4267708361148834, 0.40774476528167725, 0.42282211780548096, 0.4781952202320099, 0.5348532199859619, 0.5892888307571411, 0.40422841906547546, 0.4558173716068268, 0.36898142099380493, 0.9205887317657471, 0.6185564398765564, 0.5108629465103149, 0.7636951804161072, 0.4768432378768921, 0.8379564881324768, 0.3983945846557617, 0.4244348108768463, 0.377480149269104, 0.42376893758773804, 0.5243623852729797, 0.4117155075073242, 0.3847920894622803, 0.4880342185497284, 0.5619639754295349, 0.4753367304801941, 0.7594866156578064, 0.6272976994514465, 0.6687712073326111, 0.6804089546203613, 0.7493405342102051, 0.7857539653778076, 0.6962857246398926, 0.8121911287307739, 0.6477372646331787, 0.7025402784347534, 0.7654938697814941, 0.6455159783363342, 1.0511361360549927, 0.847347617149353, 0.609919011592865, 0.7612303495407104, 0.8660853505134583, 0.6686427593231201, 0.8942302465438843, 0.714823842048645, 0.8234495520591736, 1.2917590141296387, 0.9863629937171936, 0.8917754888534546, 0.802504301071167, 0.8014320731163025, 0.8109920024871826, 0.9435753226280212, 0.6368959546089172, 0.6719377636909485, 0.6416013836860657, 0.6162813305854797, 0.5613307952880859, 0.6572178602218628, 0.5976423621177673, 0.9072781801223755, 0.7019497752189636, 0.8401123881340027, 0.6700666546821594, 
0.6234099864959717, 0.6980035901069641, 0.6150760054588318, 0.6935156583786011, 0.7264655828475952, 0.8167094588279724]
[INFO] curiosity (n=200, episode done=True, max steps reached=True): [0.20040760934352875, 0.405997633934021, 0.2014654278755188, 0.5143175721168518, 0.24975121021270752, 0.2219068706035614, 0.28917527198791504, 0.3457401990890503, 0.1857166886329651, 0.24595946073532104, 0.2544567584991455, 0.2575584053993225, 0.6467382907867432, 0.21939019858837128, 0.24156446754932404, 0.263954222202301, 0.456066757440567, 0.42061516642570496, 0.39988625049591064, 0.3835148811340332, 0.575833261013031, 0.529363214969635, 0.5236270427703857, 0.5616291165351868, 0.5749561786651611, 0.33099061250686646, 0.3211314082145691, 0.31990307569503784, 0.8756985068321228, 0.9123344421386719, 0.855627179145813, 0.8519073128700256, 0.8876305818557739, 1.366782546043396, 0.9131333827972412, 1.2058366537094116, 1.2662551403045654, 0.980659544467926, 1.1457749605178833, 0.8143234848976135, 0.7707265019416809, 0.6527854204177856, 0.4105456471443176, 0.7295986413955688, 0.8690494894981384, 0.6910908222198486, 0.641333818435669, 0.5049877762794495, 0.4368477761745453, 0.35646283626556396, 0.31686198711395264, 0.3444564938545227, 0.35278260707855225, 0.3804481327533722, 0.4642121493816376, 0.28114044666290283, 0.363090455532074, 0.2691594660282135, 0.31924164295196533, 0.28078794479370117, 0.359183132648468, 0.31364741921424866, 0.40449196100234985, 0.3378605246543884, 0.3605378270149231, 0.2715175449848175, 0.2720174789428711, 0.49601832032203674, 0.48536407947540283, 0.4551427364349365, 0.4952602982521057, 0.6198722124099731, 0.4677537679672241, 0.5097872614860535, 0.6203373074531555, 0.4208877980709076, 0.2833283841609955, 0.5943808555603027, 0.6143324971199036, 0.6355754733085632, 0.44470831751823425, 0.546419620513916, 0.5391665101051331, 0.5582544803619385, 0.5428563952445984, 0.6655526757240295, 0.562159538269043, 0.7829023599624634, 0.7467460632324219, 0.7492837309837341, 0.7788183093070984, 0.7620762586593628, 0.7492246627807617, 0.6904751658439636, 0.49599489569664, 0.5574407577514648, 
0.6015672087669373, 0.3472583591938019, 0.3523492217063904, 0.4898557662963867, 0.49008071422576904, 0.543617308139801, 0.7521457076072693, 0.5986086130142212, 0.4664466977119446, 0.5208238363265991, 0.31878405809402466, 0.2425195872783661, 0.34365949034690857, 0.3240930736064911, 0.3329620063304901, 0.2607978284358978, 0.31962162256240845, 0.34391117095947266, 0.44687652587890625, 0.32213953137397766, 0.38312703371047974, 0.3649355173110962, 0.34303027391433716, 0.30650216341018677, 0.29493123292922974, 0.4825912117958069, 0.22040557861328125, 0.26829198002815247, 0.3844771683216095, 0.296333372592926, 0.241156205534935, 0.22011947631835938, 0.22701315581798553, 0.27934345602989197, 0.46820250153541565, 0.6166121959686279, 0.742115318775177, 0.7959161400794983, 0.5395683646202087, 0.5242422223091125, 0.5393686890602112, 0.8941488265991211, 0.43604591488838196, 0.28425338864326477, 0.3392718732357025, 0.23488372564315796, 0.3202020525932312, 0.27430957555770874, 0.24479073286056519, 0.2918567657470703, 0.603556215763092, 0.28601017594337463, 0.4155757427215576, 0.5102577209472656, 0.5628896355628967, 0.6641409397125244, 0.5935871601104736, 0.4754033088684082, 0.48434978723526, 0.3132435083389282, 0.38624298572540283, 0.3484439253807068, 0.5724354982376099, 0.8504922389984131, 0.6232136487960815, 0.633483350276947, 0.36520203948020935, 0.20835649967193604, 0.28730839490890503, 0.21950678527355194, 0.2056119292974472, 0.2782301604747772, 0.1878279596567154, 0.4237385392189026, 0.30017736554145813, 0.3848726451396942, 0.30528053641319275, 0.19472374022006989, 0.24799753725528717, 0.28432056307792664, 0.21214967966079712, 0.3687637448310852, 0.2690379321575165, 0.2466214895248413, 0.26589280366897583, 0.4564402997493744, 0.24240784347057343, 0.2574172616004944, 0.2114577293395996, 0.25341400504112244, 0.28775399923324585, 0.24803633987903595, 0.34510499238967896, 0.26538723707199097, 0.27149590849876404, 0.34061628580093384, 0.34328368306159973, 0.21388986706733704, 
0.4521488845348358, 0.30640265345573425, 0.27976229786872864, 0.29255369305610657, 0.21736940741539001, 0.2564901113510132]
...
Note that I also log “episode done” and “max steps reached” to help identify successful episodes. If the first is False and the second is True, the episode ended because the step limit was reached rather than because the goal state was reached, i.e. it was not a successful run.
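If you post-process the logged episodes, that rule can be folded into a small helper. A sketch under my own naming (the function and labels are not part of the ML-Agents API):

```python
def classify_episode(done_reached: bool, interrupted: bool) -> str:
    """Classify an episode from the two logged flags:
    done=False and max-steps=True means the step limit ended the episode."""
    if interrupted and not done_reached:
        return "step limit reached (unsuccessful)"
    if done_reached:
        return "terminal state reached"
    return "still running"

print(classify_episode(done_reached=False, interrupted=True))
print(classify_episode(done_reached=True, interrupted=False))
```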