I'm new to ML-Agents and have been experimenting with the Karting Microgame tutorials. I would like to add imitation learning to my game, and I was wondering how to make my agent kart take its movement input from the keyboard. Below is the `KartAgent` script used by the agent karts.
I saw in one of the tutorials that I should override the `Heuristic` method in this script so that I can drive my kart with the arrow keys and record its actions. How do I do that with this script? I'm struggling to find what I need to modify to make it work. Any tips would be much appreciated!
using KartGame.KartSystems;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;
using Random = UnityEngine.Random;
namespace KartGame.AI
{
    /// <summary>
    /// Sensors hold the information for one raycast used by <see cref="KartAgent"/>: its direction, its length and
    /// the distance under which a hit is considered a "crash".
    /// </summary>
    /// <remarks>
    /// Only the forward axis of <c>Transform</c> is used, as the ray direction. The ray itself is cast from
    /// <see cref="KartAgent.AgentSensorTransform"/>. [Serializable] is required so that Unity serializes
    /// <see cref="KartAgent.Sensors"/> and shows it in the Inspector.
    /// </remarks>
    [System.Serializable]
    public struct Sensor
    {
        /// <summary>The ray is cast along this transform's forward axis.</summary>
        public Transform Transform;
        /// <summary>Maximum ray length. It is also the observed value when nothing is hit.</summary>
        public float RayDistance;
        /// <summary>A hit closer than this distance counts as a crash (penalty and end of episode).</summary>
        public float HitValidationDistance;
    }

    /// <summary>
    /// We only want certain behaviours when the agent runs.
    /// Training: <c>OnEpisodeBegin()</c> respawns the kart at a random checkpoint.
    /// Inferencing: the kart starts from <c>InitCheckpointIndex</c> and is not respawned by
    /// <c>OnEpisodeBegin()</c>. Instead, <c>LateUpdate()</c> puts it back on its last checkpoint if it leaves the track.
    /// </summary>
    public enum AgentMode
    {
        Training,
        Inferencing
    }

    /// <summary>
    /// The KartAgent will drive the inputs for the KartController.
    /// Discrete ML-Agents actions (from a model, the trainer, or the keyboard via <see cref="Heuristic"/>) are
    /// decoded into steering/accelerate/brake values, which are exposed through <see cref="GenerateInput"/>.
    /// </summary>
    public class KartAgent : Agent, IInput
    {
        #region Training Modes
        [Tooltip("Are we training the agent or is the agent production ready?")]
        public AgentMode Mode = AgentMode.Training;
        [Tooltip("What is the initial checkpoint the agent will go to? This value is only for inferencing.")]
        public ushort InitCheckpointIndex;
        #endregion

        #region Senses
        [Header("Observation Params")]
        [Tooltip("What objects should the raycasts hit and detect?")]
        public LayerMask Mask;
        [Tooltip("Sensors contain ray information to sense out the world, you can have as many sensors as you need.")]
        public Sensor[] Sensors;
        [Header("Checkpoints"), Tooltip("What are the series of checkpoints for the agent to seek and pass through?")]
        public Collider[] Colliders;
        [Tooltip("What layer are the checkpoints on? This should be an exclusive layer for the agent to use.")]
        public LayerMask CheckpointMask;
        [Tooltip("Would the agent need a custom transform to be able to raycast and hit the track? " +
            "If not assigned, then the root transform will be used.")]
        public Transform AgentSensorTransform;
        #endregion

        #region Rewards
        [Header("Rewards"), Tooltip("What penatly is given when the agent crashes?")]
        public float HitPenalty = -1f;
        [Tooltip("How much reward is given when the agent successfully passes the checkpoints?")]
        public float PassCheckpointReward;
        [Tooltip("Should typically be a small value, but we reward the agent for moving in the right direction.")]
        public float TowardsCheckpointReward;
        [Tooltip("Typically if the agent moves faster, we want to reward it for finishing the track quickly.")]
        public float SpeedReward;
        [Tooltip("Reward the agent when it keeps accelerating")]
        public float AccelerationReward;
        #endregion

        #region ResetParams
        [Header("Inference Reset Params")]
        [Tooltip("What is the unique mask that the agent should detect when it falls out of the track?")]
        public LayerMask OutOfBoundsMask;
        [Tooltip("What are the layers we want to detect for the track and the ground?")]
        public LayerMask TrackMask;
        [Tooltip("How far should the ray be when casted? For larger karts - this value should be larger too.")]
        public float GroundCastDistance;
        #endregion

        #region Debugging
        [Header("Debug Option")] [Tooltip("Should we visualize the rays that the agent draws?")]
        public bool ShowRaycasts;
        #endregion

        ArcadeKart m_Kart;

        // Decoded action state, read by GenerateInput().
        bool m_Acceleration;
        bool m_Brake;
        float m_Steering;

        // Index into Colliders of the last checkpoint the kart passed.
        int m_CheckpointIndex;

        // Set by CollectObservations() when a sensor detects a crash. Consumed in Update().
        bool m_EndEpisode;
        float m_LastAccumulatedReward;

        void Awake()
        {
            m_Kart = GetComponent<ArcadeKart>();
            if (AgentSensorTransform == null) AgentSensorTransform = transform;
        }

        void Start()
        {
            // If the agent is training, then at the start of the simulation, pick a random checkpoint to train the agent.
            // (OnEpisodeBegin() only does that in Training mode.)
            OnEpisodeBegin();

            if (Mode == AgentMode.Inferencing) m_CheckpointIndex = InitCheckpointIndex;
        }

        void Update()
        {
            if (m_EndEpisode)
            {
                m_EndEpisode = false;
                // Apply the crash penalty gathered in CollectObservations(), then end the episode.
                // EndEpisode() resets the agent, which invokes OnEpisodeBegin() to respawn it.
                AddReward(m_LastAccumulatedReward);
                EndEpisode();
            }
        }

        void LateUpdate()
        {
            switch (Mode)
            {
                case AgentMode.Inferencing:
                    if (ShowRaycasts)
                        Debug.DrawRay(transform.position, Vector3.down * GroundCastDistance, Color.cyan);

                    // We want to place the agent back on the track if the agent happens to launch itself outside of the track.
                    // The ray only hits TrackMask layers, so OutOfBoundsMask layers must also be part of TrackMask.
                    if (Physics.Raycast(transform.position + Vector3.up, Vector3.down, out var hit, GroundCastDistance, TrackMask)
                        && ((1 << hit.collider.gameObject.layer) & OutOfBoundsMask) > 0)
                    {
                        // Reset the agent back to its last known agent checkpoint
                        var checkpoint = Colliders[m_CheckpointIndex].transform;
                        transform.localRotation = checkpoint.rotation;
                        transform.position = checkpoint.position;
                        m_Kart.Rigidbody.velocity = default;
                        m_Steering = 0f;
                        m_Acceleration = m_Brake = false;
                    }
                    break;
            }
        }

        void OnTriggerEnter(Collider other)
        {
            var maskedValue = 1 << other.gameObject.layer;
            var triggered = maskedValue & CheckpointMask;

            // index is -1 when `other` is not one of the Colliders.
            FindCheckpointIndex(other, out var index);

            // Ensure that the agent touched the checkpoint and that the new index is either greater than
            // m_CheckpointIndex or wraps from the last checkpoint back to the first (completing a lap).
            if (triggered > 0 &&
                (index > m_CheckpointIndex || (index == 0 && m_CheckpointIndex == Colliders.Length - 1)))
            {
                AddReward(PassCheckpointReward);
                m_CheckpointIndex = index;
            }
        }

        /// <summary>Finds the position of <paramref name="checkPoint"/> in <see cref="Colliders"/>.</summary>
        /// <param name="checkPoint">The collider to look up.</param>
        /// <param name="index">Its index in Colliders, or -1 if it is not a registered checkpoint.</param>
        void FindCheckpointIndex(Collider checkPoint, out int index)
        {
            for (int i = 0; i < Colliders.Length; i++)
            {
                if (Colliders[i].GetInstanceID() == checkPoint.GetInstanceID())
                {
                    index = i;
                    return;
                }
            }
            index = -1;
        }

        /// <summary>Returns 1 for positive values, -1 for negative values and 0 for zero. Unused in this script.</summary>
        float Sign(float value)
        {
            if (value > 0)
            {
                return 1;
            }
            if (value < 0)
            {
                return -1;
            }
            return 0;
        }

        /// <summary>
        /// Adds 1 + Sensors.Length float observations. The first is the alignment of the kart's velocity with the
        /// direction to the next checkpoint. The rest are one ray distance per sensor. A sensor hit closer than
        /// HitValidationDistance queues a crash penalty and an end of episode, which Update() applies.
        /// </summary>
        /// <remarks>
        /// NOTE(review): Behavior Parameters > Vector Observation > Space Size must equal 1 + Sensors.Length.
        /// </remarks>
        public override void CollectObservations(VectorSensor sensor)
        {
            // Add an observation for direction of the agent to the next checkpoint.
            var next = (m_CheckpointIndex + 1) % Colliders.Length;
            var nextCollider = Colliders[next];
            if (nextCollider == null)
                return;

            var direction = (nextCollider.transform.position - m_Kart.transform.position).normalized;
            sensor.AddObservation(Vector3.Dot(m_Kart.Rigidbody.velocity.normalized, direction));

            if (ShowRaycasts)
                Debug.DrawLine(AgentSensorTransform.position, nextCollider.transform.position, Color.magenta);

            m_LastAccumulatedReward = 0.0f;
            m_EndEpisode = false;
            for (var i = 0; i < Sensors.Length; i++)
            {
                var current = Sensors[i];
                var xform = current.Transform;
                var hit = Physics.Raycast(AgentSensorTransform.position, xform.forward, out var hitInfo,
                    current.RayDistance, Mask, QueryTriggerInteraction.Ignore);

                if (ShowRaycasts)
                {
                    Debug.DrawRay(AgentSensorTransform.position, xform.forward * current.RayDistance, Color.green);
                    Debug.DrawRay(AgentSensorTransform.position, xform.forward * current.HitValidationDistance,
                        Color.red);

                    if (hit && hitInfo.distance < current.HitValidationDistance)
                    {
                        Debug.DrawRay(hitInfo.point, Vector3.up * 3.0f, Color.blue);
                    }
                }

                if (hit && hitInfo.distance < current.HitValidationDistance)
                {
                    m_LastAccumulatedReward += HitPenalty;
                    m_EndEpisode = true;
                }

                sensor.AddObservation(hit ? hitInfo.distance : current.RayDistance);
            }
        }

        /// <summary>
        /// Applies the chosen discrete actions to the kart, then shapes the reward for heading towards the next
        /// checkpoint, accelerating and moving fast.
        /// </summary>
        public override void OnActionReceived(float[] vectorAction)
        {
            // Decode first so the acceleration reward below reflects the action just taken.
            InterpretDiscreteActions(vectorAction);

            // Find the next checkpoint when registering the current checkpoint that the agent has passed.
            var next = (m_CheckpointIndex + 1) % Colliders.Length;
            var nextCollider = Colliders[next];
            var direction = (nextCollider.transform.position - m_Kart.transform.position).normalized;
            var reward = Vector3.Dot(m_Kart.Rigidbody.velocity.normalized, direction);

            if (ShowRaycasts) Debug.DrawRay(AgentSensorTransform.position, m_Kart.Rigidbody.velocity, Color.blue);

            // Add rewards if the agent is heading in the right direction
            AddReward(reward * TowardsCheckpointReward);
            AddReward((m_Acceleration && !m_Brake ? 1.0f : 0.0f) * AccelerationReward);
            AddReward(m_Kart.LocalSpeed() * SpeedReward);
        }

        /// <summary>
        /// Lets a human drive the kart from the keyboard. ML-Agents calls this instead of a policy when the Behavior
        /// Parameters' Behavior Type is "Heuristic Only", or "Default" with no model and no trainer connected.
        /// To record demonstrations for imitation learning, add a Demonstration Recorder component to this agent
        /// and enable Record.
        /// </summary>
        /// <param name="actionsOut">
        /// Discrete action buffer, filled with the encoding that InterpretDiscreteActions() decodes:
        /// [0] steering: 0, 1 or 2 (decoded to -1, 0, +1); [1] throttle: 1 = accelerate, 0 = brake.
        /// </param>
        /// <remarks>
        /// Reads the legacy Input Manager axes "Horizontal" and "Vertical" (arrow keys / WASD in a default project).
        /// NOTE(review): assumes Behavior Parameters declares two discrete branches, of sizes 3 and 2. Confirm this.
        /// With an ML-Agents version whose OnActionReceived takes ActionBuffers, the override is
        /// Heuristic(in ActionBuffers actionsOut) and the values go into actionsOut.DiscreteActions.
        /// </remarks>
        public override void Heuristic(float[] actionsOut)
        {
            var horizontal = UnityEngine.Input.GetAxisRaw("Horizontal");
            if (horizontal < 0f)
                actionsOut[0] = 0f; // Left arrow / A  -> m_Steering = -1
            else if (horizontal > 0f)
                actionsOut[0] = 2f; // Right arrow / D -> m_Steering = +1
            else
                actionsOut[0] = 1f; // no key          -> m_Steering = 0

            // The throttle branch has no "coast" value, so idle must map to either accelerate or brake.
            // Accelerating by default and braking only while Down / S is held keeps the kart moving, which is
            // also what AccelerationReward encourages.
            actionsOut[1] = UnityEngine.Input.GetAxisRaw("Vertical") < 0f ? 0f : 1f;
        }

        public override void OnEpisodeBegin()
        {
            switch (Mode)
            {
                case AgentMode.Training:
                    // The int overload of Random.Range excludes its max, so pass Length to include the last checkpoint.
                    m_CheckpointIndex = Random.Range(0, Colliders.Length);
                    var startCheckpoint = Colliders[m_CheckpointIndex];
                    transform.localRotation = startCheckpoint.transform.rotation;
                    transform.position = startCheckpoint.transform.position;
                    m_Kart.Rigidbody.velocity = default;
                    m_Acceleration = false;
                    m_Brake = false;
                    m_Steering = 0f;
                    break;
                default:
                    break;
            }
        }

        /// <summary>
        /// Decodes discrete actions: actions[0] minus 1 gives the steering value (0, 1, 2 -> -1, 0, +1).
        /// actions[1] &gt;= 1 accelerates and anything lower brakes.
        /// </summary>
        void InterpretDiscreteActions(float[] actions)
        {
            m_Steering = actions[0] - 1f;
            m_Acceleration = actions[1] >= 1.0f;
            m_Brake = actions[1] < 1.0f;
        }

        /// <summary>Returns the most recently decoded actions as kart input.</summary>
        public InputData GenerateInput()
        {
            return new InputData
            {
                Accelerate = m_Acceleration,
                Brake = m_Brake,
                TurnInput = m_Steering
            };
        }
    }
}