I am training a vehicle agent to follow a simple track using behavioural cloning. During training the model appears to work fine, but after I export it and run it in Unity, the agent just gets stuck.
My config file:
# ML-Agents trainer configuration (legacy per-behavior format).
# Nesting restored: hyperparameters belong to a top-level section, and each
# reward signal's settings belong under that signal's name.

# Fallback hyperparameters.
default:
  trainer: ppo
  batch_size: 1024
  beta: 5.0e-3
  buffer_size: 10240
  epsilon: 0.2
  hidden_units: 128
  lambd: 0.99
  learning_rate: 3.0e-4
  max_steps: 5000000
  memory_size: 256
  normalize: false
  num_epoch: 3
  num_layers: 2
  time_horizon: 64
  sequence_length: 64
  summary_freq: 10000
  use_recurrent: false
  reward_signals:
    extrinsic:
      strength: 1.0
      gamma: 0.99

# NOTE(review): this section name presumably has to match the Behavior Name
# set on the agent's Behavior Parameters component - confirm in the scene.
RaceAgent:
  summary_freq: 10000
  time_horizon: 64
  batch_size: 256
  buffer_size: 2048
  hidden_units: 128
  num_layers: 2
  beta: 5.0e-4
  learning_rate_schedule: linear
  max_steps: 5.0e7
  num_epoch: 3
  behavioral_cloning:
    demo_path: RaceAgentN4_1.demo
    strength: 1.0
    steps: 150000
  reward_signals:
    extrinsic:
      strength: 0.1
      gamma: 0.99
    curiosity:
      strength: 0.01
      gamma: 0.90
      encoding_size: 256
    gail:
      strength: 1.0
      gamma: 0.99
      encoding_size: 128
      demo_path: RaceAgentN4_1.demo
using System;
/// <summary>
/// Agent that drives a <c>Vehicle</c> around a checkpoint track.
/// It observes how well its heading lines up with the next checkpoint and
/// chooses two discrete actions: throttle (idle / accelerate / brake) and
/// steering (idle / left / right).
/// </summary>
public class BikeAgent : Agent
{
    // Reward shaping values used by this agent.
    private const float CheckpointReward = 1f;
    private const float LapCompleteBonus = 1f;
    private const float WrongCheckpointPenalty = -1f;
    private const float OffTrackPenalty = -1f;
    private const float WallHitPenalty = -0.05f;
    private const float WallStayPenalty = -0.01f;

    [SerializeField] private Transform m_SpawnPos;
    [SerializeField] private Vehicle _vehicle; // ref
    [SerializeField] private TrackCheckpoints _trackCheckpoints; // ref
    [SerializeField] private Transform m_BikeSphere;

    /// <summary>
    /// Subscribes to the checkpoint events and releases the vehicle.
    /// </summary>
    public override void Initialize()
    {
        base.Initialize();
        _trackCheckpoints.OnPlayerCorrectCheckpoint += OnCorrectCheckPoint;
        // Was "-=", which meant the wrong-checkpoint penalty was never wired up.
        _trackCheckpoints.OnPlayerWrongCheckpoint += OnWrongCheckPoint;
        _vehicle.StopVehicle = false;
    }

    /// <summary>
    /// Rewards this agent when its sphere passes the correct checkpoint,
    /// with an extra bonus when that checkpoint completes a lap.
    /// </summary>
    /// <param name="carTransform">Transform that triggered the checkpoint.</param>
    /// <param name="isLapComplete">True when this checkpoint finished a lap.</param>
    void OnCorrectCheckPoint(Transform carTransform, bool isLapComplete)
    {
        // The event is shared; ignore checkpoints hit by other vehicles.
        if (carTransform != m_BikeSphere)
        {
            return;
        }

        AddReward(CheckpointReward);
        if (isLapComplete)
        {
            AddReward(LapCompleteBonus);
        }
    }

    /// <summary>
    /// Penalises this agent when its sphere passes a checkpoint out of order.
    /// </summary>
    /// <param name="carTransform">Transform that triggered the checkpoint.</param>
    void OnWrongCheckPoint(Transform carTransform)
    {
        if (carTransform == m_BikeSphere)
        {
            AddReward(WrongCheckpointPenalty);
        }
    }

    /// <summary>
    /// Resets the vehicle at the start of every episode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        base.OnEpisodeBegin();
        ResetVehicle();
    }

    /// <summary>
    /// Stops the vehicle, moves it (and its sphere) to a randomised position
    /// around the spawn point, aligns it with the spawn direction and resets
    /// its checkpoint progress.
    /// </summary>
    void ResetVehicle()
    {
        _vehicle.StopVehicle = true;

        // Fully qualified: with "using System;" in scope a bare "Random"
        // clashes with System.Random.
        float offsetX = UnityEngine.Random.Range(-3f, 3f);
        float offsetZ = UnityEngine.Random.Range(-2f, 2f);
        Vector3 spawnPosition = m_SpawnPos.position + new Vector3(offsetX, 0.75f, offsetZ);

        transform.position = spawnPosition;
        m_BikeSphere.position = spawnPosition;
        transform.forward = m_SpawnPos.forward;
        m_BikeSphere.forward = m_SpawnPos.forward;

        _trackCheckpoints.ResetCheckPoint(m_BikeSphere);
    }

    /// <summary>
    /// Adds a single observation: the dot product between the agent's forward
    /// vector and the next checkpoint's forward vector.
    /// </summary>
    /// <param name="sensor">Sensor that receives the observation.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        base.CollectObservations(sensor);
        Vector3 checkPointForward = _trackCheckpoints.GetNextCheckPoint(m_BikeSphere).transform.forward;
        float dirDot = Vector3.Dot(transform.forward, checkPointForward);
        sensor.AddObservation(dirDot);
    }

    /// <summary>
    /// Applies the chosen throttle and steering actions and a small per-step
    /// time penalty. Ends the episode with a penalty if the agent has fallen
    /// below the track (y &lt; 0).
    /// </summary>
    /// <param name="vectorAction">
    /// Index 0: throttle (0 idle, 1 accelerate, 2 brake).
    /// Index 1: steering (0 idle, 1 left, 2 right).
    /// </param>
    public override void OnActionReceived(float[] vectorAction)
    {
        base.OnActionReceived(vectorAction);

        // Fell off the track: punish, end the episode and stop here so the
        // stale action and the time penalty are not applied after termination.
        if (transform.position.y < 0f)
        {
            AddReward(OffTrackPenalty);
            EndEpisode();
            return;
        }

        _vehicle.StopVehicle = false;

        int forwardAction = Mathf.FloorToInt(vectorAction[0]);
        int turnAction = Mathf.FloorToInt(vectorAction[1]);

        switch (forwardAction)
        {
            case 1:
                // forward
                _vehicle.ControlAccelerate();
                break;
            case 2:
                // backward
                _vehicle.ControlBrake();
                break;
            default:
                // idle
                break;
        }

        switch (turnAction)
        {
            case 1:
                // left
                _vehicle.ControlSteer(-1);
                break;
            case 2:
                // right
                _vehicle.ControlSteer(1);
                break;
            default:
                // idle
                break;
        }

        // Time penalty to encourage fast laps. Skipped when MaxStep is 0,
        // where the division would produce -Infinity.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
    }

    /// <summary>
    /// Keyboard control: W accelerate, S brake, A steer left, D steer right.
    /// S takes precedence over W, and D over A, when both are held.
    /// </summary>
    /// <param name="actionsOut">Action buffer to fill (throttle, steering).</param>
    public override void Heuristic(float[] actionsOut)
    {
        base.Heuristic(actionsOut);

        // default idle
        actionsOut[0] = 0; // forward
        actionsOut[1] = 0; // turn

        if (Input.GetKey(KeyCode.W))
        {
            actionsOut[0] = 1; // accelerate
        }
        if (Input.GetKey(KeyCode.S))
        {
            actionsOut[0] = 2; // brake
        }
        if (Input.GetKey(KeyCode.A))
        {
            actionsOut[1] = 1; // turn left
        }
        if (Input.GetKey(KeyCode.D))
        {
            actionsOut[1] = 2; // turn right
        }
    }

    //todo collision obstacle reward etc

    /// <summary>
    /// Applies a one-off penalty when the agent first touches a "wall" object.
    /// </summary>
    private void OnCollisionEnter(Collision other)
    {
        if (other.gameObject.CompareTag("wall"))
        {
            AddReward(WallHitPenalty);
        }
    }

    /// <summary>
    /// Applies a smaller penalty on each OnCollisionStay call while the agent
    /// is touching a "wall" object.
    /// </summary>
    private void OnCollisionStay(Collision other)
    {
        if (other.gameObject.CompareTag("wall"))
        {
            AddReward(WallStayPenalty);
        }
    }

    // Update is called once per frame
    void Update()
    {
        //todo update UI
    }
}