Hello there,
I am looking for some insights on my rewards for seeker agent. I want him to be able to find hider in maze-like environment. I am using Raycasts (Ray Perception Sensor) as observation of environment. Currently, seeker is rewarded if he “sees” with one of rays a hider, and gets constant reward for going towards him, and negative reward for going away from him. Also he gets a reward when he touch hider, and negative reward when he collides with wall.
As it is, i am not satisfied with how the seeker is working, as now it seems that he prioritize going in circles around a map, and not prioritizing the behavior in which he goes towards a hider as soon as he sees him, which is how he should act.
Here is the code, let me know if you need anything else. All help will be appreciated.
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
public class SeekerAgent : Agent
{
[SerializeField] private float rotationSpeed;
[SerializeField] private float movementSpeed;
[SerializeField] private RayPerceptionSensorComponent3D rayPerceptionSensor;
[SerializeField] private HiderAgent hiderAgent;
[SerializeField] private MeshRenderer groundMesh;
[SerializeField] private Material winMaterial;
[SerializeField] private Material defaultMaterial;
private int episodeCounter = 0;
private Rigidbody seekerRb;
private Vector3 positionOfHider;
private bool hiderFound;
private float previousDistanceToHider = float.MaxValue;
public override void Initialize()
{
seekerRb = GetComponent<Rigidbody>();
}
public override void OnEpisodeBegin()
{
episodeCounter++;
hiderFound = false;
positionOfHider = Vector3.zero;
seekerRb.velocity = Vector3.zero;
seekerRb.angularVelocity = Vector3.zero;
transform.localPosition = new Vector3(Random.Range(-6, 4), 0.25f, 3.38f);
}
public override void OnActionReceived(ActionBuffers actions)
{
AddReward(-1f / MaxStep);
MoveAgent(actions.DiscreteActions);
if (hiderFound)
{
float currentDistanceToHider = Vector3.Distance(transform.localPosition, positionOfHider);
if (currentDistanceToHider < previousDistanceToHider)
AddReward(1 / currentDistanceToHider);
else
AddReward(-1 / currentDistanceToHider);
previousDistanceToHider = currentDistanceToHider;
}
}
private void MoveAgent(ActionSegment<int> actions)
{
var moveDirection = Vector3.zero;
var rotateDirection = Vector3.zero;
var actionsMoving = actions[0];
var actionsRotating = actions[1];
switch (actionsMoving)
{
case 1:
moveDirection = movementSpeed * Time.deltaTime * Vector3.forward;
break;
case 2:
moveDirection = movementSpeed * Time.deltaTime * Vector3.back;
break;
}
if (actionsRotating == 1)
{
rotateDirection = Vector3.up * rotationSpeed * Time.deltaTime;
}
else if (actionsRotating == 2)
{
rotateDirection = Vector3.up * -rotationSpeed * Time.deltaTime;
}
transform.Rotate(rotateDirection);
transform.Translate(moveDirection, Space.Self);
}
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
if (Input.GetKey(KeyCode.D))
discreteActionsOut[1] = 1;
else if (Input.GetKey(KeyCode.A))
discreteActionsOut[1] = 2;
if (Input.GetKey(KeyCode.W))
discreteActionsOut[0] = 1;
else if (Input.GetKey(KeyCode.S))
discreteActionsOut[0] = 2;
}
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(transform.localPosition);
sensor.AddObservation(transform.localRotation);
if (!hiderFound)
foreach (var raySensor in rayPerceptionSensor.RaySensor.RayPerceptionOutput.RayOutputs)
{
if (raySensor.HitGameObject != null && raySensor.HitGameObject.CompareTag("Hider"))
{
AddReward(0.1f);
positionOfHider = raySensor.HitGameObject.transform.localPosition;
hiderFound = true;
break;
}
}
if (hiderFound)
sensor.AddObservation(positionOfHider);
else
sensor.AddObservation(Vector3.zero);
}
private void OnCollisionEnter(Collision other)
{
if (other.gameObject.CompareTag("Wall"))
{
AddReward(-0.1f);
}
if (other.gameObject.CompareTag("Hider"))
{
AddReward(3f);
if (episodeCounter % 2 == 0)
groundMesh.material = winMaterial;
else
groundMesh.material = defaultMaterial;
hiderAgent.EndEpisode();
EndEpisode();
}
}
}