[ML-Agents] Best Training Method for cars using Wheel Collider to navigate a track

Hi, I am trying to train 2 cars that rely on WheelColliders for movement and steering. I wanted them to behave like an F1 car, so the acceleration is fast and the handling is restricted: the steering angle is limited to about 15 degrees (maxTurnAngle * turnSensitivity in the code below). I want to find the best way to train them.
I previously used Transform and Relative Force for movement, and with that setup I was able to train my cars to navigate 2 tracks almost to perfection, as in they’d complete the track 98 out of 100 times. I was following a method where I train them to go forward first, then to take turns, and then start them on the track section by section, using --initialize-from to reuse the previous brain to aid training.

This time, with Wheel Colliders, I am using GAIL and BC too, with 2 cars on one track. However, the cars crash more often into walls and into each other, and because the handling is more restricted than before, they struggle to reverse and realign themselves. Below is my code for my CarController integrated with ML-Agents, and my YAML file. Any advice or help would be appreciated, and do let me know if any more information is needed. I am open to joining a meeting too and sharing my work if that’s more convenient.

> public class CarControllerImproved : Agent
> {
>     // Serialized references to the four wheel colliders and the four wheel transforms.
>     [SerializeField] WheelCollider frontLeft;
>     [SerializeField] WheelCollider frontRight;
>     [SerializeField] WheelCollider rearLeft;
>     [SerializeField] WheelCollider rearRight;
>     [SerializeField] Transform frontRightTransform;
>     [SerializeField] Transform frontLeftTransform;
>     [SerializeField] Transform rearRightTransform;
>     [SerializeField] Transform rearLeftTransform;
> 
> 
> 
>     public float maxAcceleration = 500f; // not referenced by the methods shown in this class
> 
>     public float maxTorque = 500f; // motor torque set on each of the four wheels in OnActionReceived
> 
>     public float brakingforce = 600f; // multiplied by 200 in OnActionReceived to get the brake torque
>     public float maxTurnAngle = 30f; // scaled by turnSensitivity to get the target steer angle
>     public float maxSpeed = 120f; // not referenced by the methods shown in this class
> 
>     public Vector3 centreOfMass; // copied to rb.centerOfMass in Initialize
>     
>     private float currentAcceleration = 0f; // not referenced by the methods shown in this class
>     private float currentBrakeForce = 0f; // brake torque applied to all four wheels each action
>     
>     //Turning
>     private float moveInput, steerInput; // keyboard axis values read in Heuristic
>     private float turnSensitivity = 0.5f; // with maxTurnAngle = 30 this gives a 15 degree target steer angle
>     private float _steerAngle = 0f; // target steer angle the front wheels are lerped towards
>     
>     Rigidbody rb; // fetched with GetComponent in Initialize
> 
> //MLAgents
>     public float multfwd;   //forward reward (added when drive action 2 is chosen)
>     public float multback; //backward reward (added when drive action 1 is chosen)
>     private Vector3 recall_position;            //spawn position
>     private Quaternion recall_rotation; // spawn rotation, captured in Initialize
>     public bool doEpisodes = true; // not referenced by the methods shown in this class
>     private int currentCarPosition, previousCarPosition; // values returned by RaceManager.Instance.GetPosition
>     public float BetterLapReward = 0.0f; // value of LapTimer.BetterTime(); subtracted from the finish reward
>     public List<GameObject> checkPointList, allCheckpointList; // checkpoints passed this episode / list taken from RaceManager
>     GameObject nextCheckpoint; // checkpoint whose forward vector is used for the alignment reward
> 
> 
> 
>     /// <summary>
>     /// One-time agent setup: caches the Rigidbody, applies the configured centre of mass,
>     /// records the spawn pose and creates empty checkpoint lists.
>     /// </summary>
>     public override void Initialize()
>     {
>         rb = GetComponent<Rigidbody>();
>         rb.centerOfMass = centreOfMass;
> 
>         // Vector3 and Quaternion are value types, so a plain assignment stores a copy of the spawn pose.
>         recall_position = transform.position;
>         recall_rotation = transform.rotation;
> 
>         checkPointList = new List<GameObject>();
>         allCheckpointList = new List<GameObject>();
>     }
>     /// <summary>
>     /// Resets the car for a new episode: stops all motion, restores the spawn pose,
>     /// reloads the checkpoint list from RaceManager and clears per-episode progress.
>     /// </summary>
>     public override void OnEpisodeBegin()
>     {
>         // Zero both linear AND angular velocity. Resetting only the linear part lets a car
>         // that was spinning when the episode ended keep rotating after it is respawned.
>         rb.velocity = Vector3.zero;
>         rb.angularVelocity = Vector3.zero;
> 
>         this.transform.position = recall_position;
>         this.transform.rotation = recall_rotation;
> 
>         // Clear the wheel state left over from the last action of the previous episode so the
>         // respawned car does not start out driving, braking or steering.
>         frontLeft.motorTorque = 0f;
>         frontRight.motorTorque = 0f;
>         rearLeft.motorTorque = 0f;
>         rearRight.motorTorque = 0f;
>         frontLeft.brakeTorque = 0f;
>         frontRight.brakeTorque = 0f;
>         rearLeft.brakeTorque = 0f;
>         rearRight.brakeTorque = 0f;
>         frontLeft.steerAngle = 0f;
>         frontRight.steerAngle = 0f;
> 
>         allCheckpointList = RaceManager.Instance.monzaCheckpoints;
>         nextCheckpoint = allCheckpointList[0];
>         checkPointList.Clear();
>         previousCarPosition = RaceManager.Instance.GetPosition(gameObject);
>     }
> 
>     /// <summary>
>     /// Applies one decision to the car. Three discrete branches are read:
>     /// branch 0 = drive (0 none, 1 +maxTorque with the multback reward, 2 -maxTorque with the multfwd reward),
>     /// branch 1 = steering (0 straight, 1 left, 2 right),
>     /// branch 2 = brake (0 released, 1 applied).
>     /// </summary>
>     /// <param name="actions">Action buffers supplied by the policy or by Heuristic.</param>
>     public override void OnActionReceived(ActionBuffers actions)
>     {
>         // Index the segment itself rather than its backing Array so the segment's offset is respected.
>         var discreteActions = actions.DiscreteActions;
> 
>         // How closely the car's reversed forward axis lines up with the next checkpoint's forward axis.
>         float alignmentDot = Vector3.Dot(-transform.forward, nextCheckpoint.gameObject.transform.forward);
> 
>         // Branch 0: drive torque on all four wheels.
>         switch (discreteActions[0])
>         {
>             case 0:
>                 SetMotorTorque(0f);
>                 break;
>             case 1:
>                 SetMotorTorque(maxTorque);
>                 AddReward(multback);
>                 break;
>             case 2:
>                 SetMotorTorque(-maxTorque);
>                 AddReward(multfwd);
>                 break;
>         }
> 
>         // Branch 1: steering direction.
>         float steerDirection = 0f;
>         bool isSteering = false;
>         switch (discreteActions[1])
>         {
>             case 1:
>                 steerDirection = -1f; //left
>                 isSteering = true;
>                 break;
>             case 2:
>                 steerDirection = 1f; //right
>                 isSteering = true;
>                 break;
>         }
> 
>         // Small bonus for a steering input made while already well aligned with the next checkpoint.
>         if (isSteering && alignmentDot > 0.8f)
>         {
>             AddReward(alignmentDot / 10000f);
>         }
> 
>         // Ease the front wheels towards the target angle instead of snapping to it.
>         _steerAngle = steerDirection * turnSensitivity * maxTurnAngle;
>         frontLeft.steerAngle = Mathf.Lerp(frontLeft.steerAngle, _steerAngle, 0.6f);
>         frontRight.steerAngle = Mathf.Lerp(frontRight.steerAngle, _steerAngle, 0.6f);
> 
>         // Branch 2: brake. Any other value keeps the previous brake force.
>         switch (discreteActions[2])
>         {
>             case 0:
>                 currentBrakeForce = 0f;
>                 break;
>             case 1:
>                 currentBrakeForce = brakingforce * 200;
>                 break;
>         }
> 
>         frontLeft.brakeTorque = currentBrakeForce;
>         frontRight.brakeTorque = currentBrakeForce;
>         rearLeft.brakeTorque = currentBrakeForce;
>         rearRight.brakeTorque = currentBrakeForce;
>     }
> 
>     /// <summary>Sets the same motor torque on all four wheel colliders.</summary>
>     private void SetMotorTorque(float torque)
>     {
>         rearLeft.motorTorque = torque;
>         rearRight.motorTorque = torque;
>         frontLeft.motorTorque = torque;
>         frontRight.motorTorque = torque;
>     }
> 
>     /// <summary>
>     /// Keyboard control used for manual driving and for recording demonstrations.
>     /// Writes the same three discrete branches that OnActionReceived reads:
>     /// branch 0 = drive, branch 1 = steering, branch 2 = brake (Space).
>     /// </summary>
>     /// <param name="actionsOut">Action buffers to fill with the player's input.</param>
>     public override void Heuristic(in ActionBuffers actionsOut)
>     {
>         // Write through the segment's indexer rather than its backing Array so the
>         // segment's offset is respected.
>         var discreteActions = actionsOut.DiscreteActions;
> 
>         // The vertical axis is negated, so a positive moveInput selects action 1 and a negative one action 2.
>         moveInput = -Input.GetAxis("Vertical");
>         steerInput = Input.GetAxis("Horizontal");
> 
>         if (moveInput > 0)
>         {
>             discreteActions[0] = 1;    //back
>         }
>         else if (moveInput < 0)
>         {
>             discreteActions[0] = 2;    //forward
>         }
>         else
>         {
>             discreteActions[0] = 0;
>         }
> 
>         if (steerInput < 0)
>         {
>             discreteActions[1] = 1;    //left
>         }
>         else if (steerInput > 0)
>         {
>             discreteActions[1] = 2;    //right
>         }
>         else
>         {
>             discreteActions[1] = 0;
>         }
> 
>         discreteActions[2] = Input.GetKey(KeyCode.Space) ? 1 : 0;
>     }
> 
>     /// <summary>
>     /// Handles checkpoint and finish-line triggers.
>     /// A checkpoint entered with the car's forward axis opposing the checkpoint's forward axis
>     /// (negative dot product) counts as progress; otherwise it is penalised as the wrong direction.
>     /// Entering a trigger tagged "Final" pays the finish reward and ends the episode.
>     /// </summary>
>     /// <param name="other">The trigger collider that was entered.</param>
>     public void OnTriggerEnter(Collider other)
>     {
>         BetterLapReward = gameObject.GetComponent<LapTimer>().BetterTime();
> 
>         if (other.gameObject.tag == "Checkpoint")
>         {
>             float directionDot = Vector3.Dot(transform.forward, other.gameObject.transform.forward);
> 
>             if (directionDot < 0)
>             {
>                 if (!checkPointList.Contains(other.gameObject))
>                 {
>                     checkPointList.Add(other.gameObject);
>                     AddReward(3.0f);
> 
>                     // Advance the target only while another checkpoint exists. Indexing with
>                     // Count after the last checkpoint was collected would be out of range.
>                     if (checkPointList.Count < allCheckpointList.Count)
>                     {
>                         nextCheckpoint = allCheckpointList[checkPointList.Count];
>                     }
> 
>                     Debug.Log("Checkpoint reached");
>                 }
>                 else
>                 {
>                     AddReward(0.05f);
>                     Debug.Log("Passing same checkpoint");
>                 }
>             }
>             else
>             {
>                 AddReward(-5.0f);
>                 Debug.Log("Wrong checkpoint, start over");
>                 checkPointList.Remove(other.gameObject);
>             }
>         }
> 
>         if (other.gameObject.tag == "Final")
>         {
>             Debug.Log("The agent has completed the track " + BetterLapReward);
>             AddReward(8.0f - (float)BetterLapReward - ((float)currentCarPosition / 3.0f));
>             checkPointList.Clear();
> 
>             // EndEpisode() already invokes OnEpisodeBegin(), so it is not called again here;
>             // calling it manually would run the reset twice.
>             EndEpisode();
>         }
> 
>         Debug.Log(currentCarPosition);
>     }
> 
>     /// <summary>
>     /// Penalises the first contact with a wall or another car, and ends the episode with a
>     /// larger penalty when the car collides with an object tagged "StartLine".
>     /// </summary>
>     /// <param name="collision">Collision data for the contact that just started.</param>
>     public void OnCollisionEnter(Collision collision)
>     {
>         // Read the tag once; every branch below compares against it.
>         string hitTag = collision.gameObject.tag;
> 
>         if (hitTag == "Wall")
>         {
>             AddReward(-0.2f);
>         }
> 
>         if (hitTag == "Car")
>         {
>             AddReward(-0.01f);
>         }
> 
>         if (hitTag == "StartLine")
>         {
>             Debug.Log("Reached the start line again, turn around");
>             AddReward(-2.0f);
>             checkPointList.Clear();
> 
>             // EndEpisode() already invokes OnEpisodeBegin(), so it is not called again here;
>             // calling it manually would run the reset twice.
>             EndEpisode();
>         }
>     }
> 
>     /// <summary>
>     /// Applies a small penalty on every physics callback while the car stays in contact
>     /// with a wall or with another car.
>     /// </summary>
>     /// <param name="collision">Collision data for the ongoing contact.</param>
>     public void OnCollisionStay(Collision collision)
>     {
>         string touchedTag = collision.gameObject.tag;
> 
>         // Walls and cars carry the same per-callback penalty.
>         if (touchedTag == "Wall" || touchedTag == "Car")
>         {
>             AddReward(-0.01f);
>         }
>     }
> 
>     /// <summary>
>     /// Polls the race position every frame and rewards or penalises a change in position:
>     /// a lower position number than before is rewarded, a higher one is penalised.
>     /// </summary>
>     void Update()
>     {
>         currentCarPosition = RaceManager.Instance.GetPosition(gameObject);
> 
>         // Nothing to do while the position is unchanged.
>         if (currentCarPosition == previousCarPosition)
>         {
>             return;
>         }
> 
>         bool gainedPlace = currentCarPosition < previousCarPosition;
>         if (gainedPlace)
>         {
>             AddReward(2.0f);
>             Debug.Log($"Overtake detected! New Position: {currentCarPosition}, Reward Given.");
>         }
>         else
>         {
>             AddReward(-1.0f);
>             Debug.Log($"Position dropped. Current Position: {currentCarPosition}, Penalty Given.");
>         }
> 
>         // Remember the new position so the same change is only rewarded once.
>         previousCarPosition = currentCarPosition;
>     }   

YAML File

> default_settings: null
> behaviors:
>   F1Forward:  # behavior name; run_id below uses the same string
>     trainer_type: ppo
>     hyperparameters:
>       batch_size: 1024
>       buffer_size: 10240  # 10x batch_size
>       learning_rate: 0.0003
>       beta: 0.005
>       epsilon: 0.2
>       lambd: 0.95
>       num_epoch: 3
>       shared_critic: false
>       learning_rate_schedule: linear
>       beta_schedule: linear
>       epsilon_schedule: linear
>     network_settings:
>       normalize: true
>       hidden_units: 128
>       num_layers: 2
>       vis_encode_type: simple
>       memory: null
>       goal_conditioning_type: hyper
>       deterministic: false
>     reward_signals:
>       extrinsic:
>         gamma: 0.99
>         strength: 1.0
>         network_settings:
>           normalize: true
>           hidden_units: 128
>           num_layers: 2
>           vis_encode_type: simple
>           memory: null
>           goal_conditioning_type: hyper
>           deterministic: false
>       gail:  # second reward signal, alongside extrinsic
>         strength: 0.5  # half the extrinsic strength
>         demo_path: Demos\MonzaS2.demo
>     behavioral_cloning:
>       strength: 1.0
>       demo_path: Demos\MonzaS2.demo  # same demo file as the gail signal above
>     init_path: null
>     keep_checkpoints: 5
>     checkpoint_interval: 700000  # equal to max_steps
>     max_steps: 700000
>     time_horizon: 64
>     summary_freq: 50000
>     threaded: false
>     self_play: null
> env_settings:
>   env_path: null
>   env_args: null
>   base_port: 5005
>   num_envs: 1
>   num_areas: 1
>   seed: -1
>   max_lifetime_restarts: 10
>   restarts_rate_limit_n: 1
>   restarts_rate_limit_period_s: 60
> engine_settings:
>   width: 84
>   height: 84
>   quality_level: 5
>   time_scale: 20
>   target_frame_rate: -1
>   capture_frame_rate: 60
>   no_graphics: false
> environment_parameters: null
> checkpoint_settings:
>   run_id: F1Forward
>   initialize_from: null  # no previous run is loaded in this configuration
>   load_model: false
>   resume: false
>   force: false
>   train_model: true
>   inference: false
>   results_dir: results
> torch_settings:
>   device: null
> debug: false