Hi,
I'm trying to train a car-driving agent to go around a simple track. My current goal is for the Agent to be able to tell the difference between a right and a left turn (90° turns). The problem is that the Agent always seems to learn left turns much faster, and after that it doesn't even attempt the right turns at all.
The interesting part is that if I train the Agent on the track containing only right turns, it learns it with no problem; but if I add the other track as well, the Agent eventually learns to turn left and forgets how to turn right.
Here is a screenshot of the environment from above:
Each track has 8 individual car Agents.
For observations I use RayPerceptionSensor3Ds to detect the walls, and I have 5 vector observations in addition:
- The magnitude of the velocity vector
- The dot product of the car’s forward vector and the next checkpoint’s forward vector
- The dot product of the car's cross vector (cross product of its forward and up vectors) and the next checkpoint's forward vector
- The previous two, but for the second-next checkpoint
So my idea was that, using the dot products, the Agent should be able to identify what kind of turn is coming. The reason I use dot products instead of the actual vectors is that a dot product is independent of the vectors' orientation in world space, so every right turn looks the same to the Agent. I thought this would be much easier for the Agent to learn: if the dot product with the cross vector is negative, it should turn right; if it's positive, it should turn left.
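To sanity-check that sign convention, here is a small standalone sketch. It is plain Python rather than Unity C# so it can be run anywhere; the `cross` function uses the same component formula as Unity's `Vector3.Cross`, vectors are `(x, y, z)` tuples, and the headings are arbitrary example values. It shows that for a left turn the dot product of `Cross(forward, up)` with the checkpoint's forward is positive, for a right turn it is negative, regardless of the car's world-space heading:

```python
import math

def cross(a, b):
    # Same component formula as Unity's Vector3.Cross.
    return (a[1]*b[2] - a[2]*b[1],
            a[2]*b[0] - a[0]*b[2],
            a[0]*b[1] - a[1]*b[0])

def dot(a, b):
    return sum(x*y for x, y in zip(a, b))

def yaw_to_forward(deg):
    # Unity-style heading on the XZ plane: 0° faces +Z, +90° faces +X.
    r = math.radians(deg)
    return (math.sin(r), 0.0, math.cos(r))

up = (0.0, 1.0, 0.0)

for heading in (0.0, 37.0, 180.0, 275.0):
    fwd = yaw_to_forward(heading)
    side = cross(fwd, up)                      # Cross(forward, up): the car's left side
    left_cp = yaw_to_forward(heading - 90.0)   # checkpoint forward after a 90° left turn
    right_cp = yaw_to_forward(heading + 90.0)  # checkpoint forward after a 90° right turn
    # The signs do not depend on the world-space heading:
    print(heading,
          round(dot(side, left_cp), 3),   # left turn  -> +1.0
          round(dot(side, right_cp), 3))  # right turn -> -1.0
```

So the observation itself is symmetric between left and right turns, which is the behavior I wanted.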
The rest of the project is quite basic, but here is the source code of the Agent:
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

public class Player : Agent
{
    private CarController carController;
    [SerializeField] private Rigidbody carRB;
    [SerializeField] private Transform carTransform;
    [SerializeField] private TrackScript Track;
    [SerializeField] private InputController inputController;
    private int nextCheckPoint = 0;

    public override void Initialize()
    {
        carController = GetComponent<CarController>();
    }

    public override void OnEpisodeBegin()
    {
        // Reset the car to the start line, facing the track, with zero velocity.
        transform.localPosition = new Vector3(0f, 0.5f, -5f);
        nextCheckPoint = 0;
        carRB.velocity = Vector3.zero;
        transform.localRotation = Quaternion.Euler(0f, 180f, 0f);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Cross(forward, up) gives the car's side vector; its dot product with the
        // checkpoint's forward encodes the direction of the upcoming turn.
        Vector3 crossVec = Vector3.Cross(carTransform.forward, carTransform.up);
        Vector3 checkPointForward = Track.GetCheckPointTransform(nextCheckPoint).forward;
        sensor.AddObservation(carRB.velocity.magnitude);
        sensor.AddObservation(Vector3.Dot(checkPointForward, carTransform.forward));
        sensor.AddObservation(Vector3.Dot(checkPointForward, crossVec));
        checkPointForward = Track.GetCheckPointTransform(nextCheckPoint + 1).forward;
        sensor.AddObservation(Vector3.Dot(checkPointForward, carTransform.forward));
        sensor.AddObservation(Vector3.Dot(checkPointForward, crossVec));
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        // Existential penalty; over a full episode of MaxStep steps it sums to -10.
        AddReward(-10f / MaxStep);
        // Penalize facing against the track direction.
        Vector3 checkPointForward = Track.GetCheckPointTransform(nextCheckPoint).forward.normalized;
        if (Vector3.Dot(checkPointForward, carTransform.forward.normalized) < 0f)
        {
            AddReward(-0.1f);
        }
        carController.Throttle = actions.ContinuousActions[0];
        carController.Steer = actions.ContinuousActions[1];
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        continuousActions[0] = inputController.ThrottleInput;
        continuousActions[1] = inputController.SteerInput;
    }

    private void OnCollisionEnter(Collision collision)
    {
        AddReward(-1f);
    }

    private void OnCollisionStay(Collision collision)
    {
        AddReward(-0.05f);
    }

    private void OnTriggerEnter(Collider other)
    {
        if (other.CompareTag("CheckPoint"))
        {
            int newID = Track.CrossCheckPoint(nextCheckPoint, other.transform);
            if (newID != nextCheckPoint)
            {
                nextCheckPoint = newID;
                AddReward(2f);
            }
        }
    }
}
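For reference, the reward terms above add up as sketched below. This is plain Python, not part of the project; the numbers passed in at the bottom (including `max_step=5000`) are hypothetical, since my actual MaxStep value is not shown here:

```python
def episode_reward(max_step, steps_survived, checkpoints,
                   collisions, stay_steps, wrong_way_steps):
    """Sum of the reward terms from OnActionReceived / collisions / triggers."""
    r = 0.0
    r += steps_survived * (-10.0 / max_step)  # existential penalty, -10 over a full episode
    r += wrong_way_steps * -0.1               # steps spent facing against the checkpoint
    r += checkpoints * 2.0                    # checkpoint crossings
    r += collisions * -1.0                    # wall hits (OnCollisionEnter)
    r += stay_steps * -0.05                   # steps in contact with a wall (OnCollisionStay)
    return r

# Hypothetical full-length clean episode crossing 10 checkpoints:
# -10 from the time penalty, +20 from checkpoints, so about +10 total.
print(episode_reward(max_step=5000, steps_survived=5000, checkpoints=10,
                     collisions=0, stay_steps=0, wrong_way_steps=0))
```

So checkpoints dominate the reward as long as the car keeps moving forward.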
And here is the config file:
behaviors:
  DriveCar:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2048
      buffer_size: 40960
      learning_rate: 0.0003
      beta: 0.003
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 4
      learning_rate_schedule: linear
    network_settings:
      hidden_units: 512
      num_layers: 4
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.995
        strength: 1.0
    keep_checkpoints: 10
    checkpoint_interval: 1000000
    max_steps: 50000000
    time_horizon: 512
    summary_freq: 50000
If you have any idea why this is happening, I would be grateful if you could share it with me.
I checked, and all the checkpoints are facing the correct way, so that's not the problem.