ML-Agents cars stop learning

Hello everyone, I am currently working on a self-driving car system using ML-Agents.
This is my environment:

  • Observations:
 {   
        // Body of the agent's observation method: adds 5 floats in total
        // (3 for the checkpoint offset, 1 for the heading term, 1 for the speed).

        // Car speed: rigidbody velocity expressed in the car's local frame;
        // only the local z component is kept, scaled by 2.
        Vector3 LocalVelocity = carRb.transform.InverseTransformDirection(carRb.velocity);
        carSpeed = LocalVelocity.z * 2.0f;
       

        // Car heading relative to the next checkpoint: the y component of
        // cross(checkpoint forward, car forward). It is 0 when the two directions
        // are parallel and changes sign with the side the car is turned towards.
        Vector3 checkPointForward = trackCheckpoints.GetNextCheckpoint(transform).transform.forward;
        float directionDot = Vector3.Dot(Vector3.Cross(checkPointForward.normalized, transform.forward.normalized), Vector3.up);


        // Offset from the car to the next checkpoint (difference of the two
        // transform positions; not converted into the car's local frame).
        Vector3 diff = trackCheckpoints.GetNextCheckpoint(transform).transform.position - transform.position;
        

        sensor.AddObservation(diff);
        sensor.AddObservation(directionDot);
        sensor.AddObservation(carSpeed);
    }

+ a ray sensor

  • Reward:
    - Add +1 for each correct checkpoint
    - Set -1 on a wall collision or a wrong checkpoint

This is my config file:

# Trainer configuration for the CarDriver behavior: PPO trained on the
# environment (extrinsic) reward plus a GAIL signal fed by a recorded demo.
behaviors:
  CarDriver:
    trainer_type: ppo
    hyperparameters:
      batch_size: 256
      buffer_size: 2048
      learning_rate: 0.0002
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      # NOTE(review): "constant" presumably keeps the learning rate fixed for the
      # whole run (no decay) - confirm this is intended with max_steps this large.
      learning_rate_schedule: constant
    network_settings:
      # NOTE(review): the agent code adds its vector observations unscaled
      # (checkpoint offset, speed * 2); with normalization off they appear to
      # reach the network as-is - confirm this is intended.
      normalize: false
      hidden_units: 512
      num_layers: 2
      # NOTE(review): presumably only relevant to visual observations; the agent
      # code shown adds vector observations only - confirm.
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 0.8
      # NOTE(review): the imitation signal is configured with a fixed strength and
      # no visible cut-off, so it presumably stays active for the whole run - verify.
      gail: 
        strength: 0.25
        use_actions: true
        # Recorded demonstration file; presumably resolved relative to where
        # training is launched - TODO confirm.
        demo_path: Demos/Finaltest.demo
        network_settings:
          normalize: false
          hidden_units: 512
          num_layers: 2
    keep_checkpoints: 5
    max_steps: 500000000
    time_horizon: 256
    summary_freq: 12000
    threaded: true

Agent Code:

using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using static UnityEngine.GraphicsBuffer;

/// <summary>
/// ML-Agents agent that drives a car around a checkpoint track.
/// Observations: offset to the next checkpoint (3 floats), a signed heading term
/// relative to that checkpoint (1 float) and the scaled forward speed (1 float).
/// Actions: two continuous values, forward amount and turn amount.
/// Rewards: +1 per correct checkpoint and at track end, -1 on a wrong checkpoint
/// or a collision with an object tagged "Wall" (both of which end the episode).
/// </summary>
public class CarAgent : Agent
{
    [SerializeField] private TrackCheckpoints trackCheckpoints;
    [SerializeField] private Transform spawnPosition;
    [SerializeField] private Rigidbody carRb;
    private CarControllers carControllers;
    private float carSpeed;

    /// <summary>Caches the car controller that lives on the same GameObject.</summary>
    private void Awake()
    {
        carControllers = GetComponent<CarControllers>();
    }

    /// <summary>Subscribes to the checkpoint events raised by the track.</summary>
    private void Start()
    {
        trackCheckpoints.OnCarEndTrack += TrackCheckpoints_OnCarEndTrack;
        trackCheckpoints.OnCarCorrectCheckpoint += TrackCheckpoints_OnCarCorrectCheckpoint;
        trackCheckpoints.OnCarWrongCheckpoint += TrackCheckpoints_OnCarWrongCheckpoint;
    }

    /// <summary>
    /// Removes the handlers added in <see cref="Start"/> so the track does not keep
    /// invoking callbacks on a destroyed agent.
    /// </summary>
    private void OnDestroy()
    {
        if (trackCheckpoints == null)
        {
            return;
        }

        trackCheckpoints.OnCarEndTrack -= TrackCheckpoints_OnCarEndTrack;
        trackCheckpoints.OnCarCorrectCheckpoint -= TrackCheckpoints_OnCarCorrectCheckpoint;
        trackCheckpoints.OnCarWrongCheckpoint -= TrackCheckpoints_OnCarWrongCheckpoint;
    }

    /// <summary>Rewards this agent when its own car passes the correct checkpoint.</summary>
    private void TrackCheckpoints_OnCarCorrectCheckpoint(object sender, TrackCheckpoints.CarCheckpointEventArgs e)
    {
        if (e.carTransform == transform)
        {
            AddReward(+1f);
        }
    }

    /// <summary>Rewards this agent when its own car reaches the end of the track.</summary>
    private void TrackCheckpoints_OnCarEndTrack(object sender, TrackCheckpoints.CarCheckpointEventArgs e)
    {
        if (e.carTransform == transform)
        {
            Debug.Log("END");
            AddReward(+1f);
        }
    }

    /// <summary>
    /// Penalises this agent and ends its episode when its own car passes a wrong checkpoint.
    /// </summary>
    private void TrackCheckpoints_OnCarWrongCheckpoint(object sender, TrackCheckpoints.CarCheckpointEventArgs e)
    {
        // Every subscribed agent receives this event, so the episode must only be
        // ended for the car that actually made the mistake. Ending it outside this
        // check would reset all other agents mid-run without any matching penalty.
        if (e.carTransform == transform)
        {
            AddReward(-1f);
            EndEpisode();
        }
    }

    /// <summary>
    /// Puts the car back at the spawn point, resets its checkpoint progress and stops it.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        transform.position = spawnPosition.position;
        transform.forward = spawnPosition.forward;
        trackCheckpoints.ResetCheckPoints(transform);
        carControllers.StopCompletly();
    }

    /// <summary>
    /// Adds 5 floats to the vector sensor: checkpoint offset (3), heading term (1), speed (1).
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Car speed: velocity in the car's local frame, local z component scaled by 2.
        Vector3 localVelocity = carRb.transform.InverseTransformDirection(carRb.velocity);
        carSpeed = localVelocity.z * 2.0f;

        // Look the next checkpoint up once; both observations below are derived from it.
        Transform nextCheckpoint = trackCheckpoints.GetNextCheckpoint(transform).transform;

        // Heading relative to the checkpoint: y component of
        // cross(checkpoint forward, car forward); 0 when the directions are parallel.
        Vector3 headingCross = Vector3.Cross(nextCheckpoint.forward.normalized, transform.forward.normalized);
        float directionDot = Vector3.Dot(headingCross, Vector3.up);

        // Offset from the car to the next checkpoint (difference of transform positions).
        Vector3 diff = nextCheckpoint.position - transform.position;

        sensor.AddObservation(diff);
        sensor.AddObservation(directionDot);
        sensor.AddObservation(carSpeed);
    }

    /// <summary>Forwards the two continuous actions to the car controller.</summary>
    /// <param name="actions">Action buffers; index 0 is forward amount, index 1 is turn amount.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        float forwardAmount = actions.ContinuousActions[0];
        float turnAmount = actions.ContinuousActions[1];

        carControllers.SetInputs(forwardAmount, turnAmount);
    }

    /// <summary>Maps the "Vertical" and "Horizontal" input axes to the two continuous actions.</summary>
    /// <param name="actionsOut">Action buffers to fill with the manual input.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;

        continuousActionsOut[0] = Input.GetAxis("Vertical");
        continuousActionsOut[1] = Input.GetAxis("Horizontal");
    }

    /// <summary>Penalises the agent and ends the episode when the car hits a wall.</summary>
    /// <param name="collision">Collision data supplied by the physics engine.</param>
    private void OnCollisionEnter(UnityEngine.Collision collision)
    {
        if (collision.gameObject.CompareTag("Wall"))
        {
            // NOTE(review): SetReward is used here while the checkpoint handlers use
            // AddReward; presumably this replaces any reward already accumulated in
            // the current step - confirm that is intended.
            SetReward(-1f);
            EndEpisode();
        }
    }
}

My issue is that at a certain point, usually after about 2M steps, the policy collapses: even though the cars (5 agents are trained simultaneously) were learning how to complete the track, they forget everything they had learned and keep colliding with the same wall close to the start.