I don’t actually use the goal’s position as an observation; I use the goal’s position relative to the agent’s reference frame:
sensor.AddObservation(this.transform.InverseTransformPoint(goal.transform.position));
This is like the difference between giving someone directions (left at the stoplight then right at the blue sign) and just telling them the destination (go to the gas station).
This greatly simplifies the problem compared to having the agent search for the goal via raycast hits, although I’m fairly confident we’ll be able to train without this observation later.
Here’s the agent script with my changes. I tried to put “CHANGED” next to anything I altered:
using System.Linq;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;
/// <summary>
/// ML-Agents agent that pushes itself around an arena with physics forces and can jump.
/// It is rewarded (+1) for colliding with an object tagged "Prize" and penalised (-1) for
/// colliding with an object tagged "Enemy", dropping below local y = 0, or reaching the
/// step limit. A small per-step penalty is applied on every action.
/// </summary>
public class PlayerAI : Agent
{
    // Ground tags each box list is restricted to when it is (re)positioned.
    private static readonly string[] Ground1Tags = { "Ground1" };
    private static readonly string[] Ground2Tags = { "Ground2" };

    public Transform goal;
    public GameObject box1;
    public GameObject enemy;
    public GameObject boundLimits;
    public GameObject innerWall;
    private float _disstanceToTheGround;
    public bool _isGrounded;
    private Rigidbody _rPlayer;
    private Rigidbody _rGoal;
    public float timer;
    public int _playerLives = 1;
    public float speed = 800f;
    // NOTE(review): starts at 10 here but ResetScene sets it to 150, and nothing in this class
    // decrements it (the decrement in JumpAgent is commented out) - confirm the intended budget.
    public int _jumpsLeft = 10;
    private List<GameObject> _enemies = new List<GameObject>();
    private List<Rigidbody> _rEnemies = new List<Rigidbody>();
    private List<GameObject> _boxes1 = new List<GameObject>();
    private List<GameObject> _boxes2 = new List<GameObject>();
    public int roundtimer = 0;
    private float _enemySpeed = 100f;
    public int _enemyCount = 0;
    public int _boxCount = 1;
    private int[] _degrees = { 90, 180, 270 };
    private int[] _forceIntervals = { 2, 8, 7 };
    private int[] _forceDirections = { 0, 1, -1 };
    private string[] tagsToSpawnOn = { "Ground1", "Ground2" };
    private string _playerSpawnedOn;
    private Bounds _areaBounds;
    private float _xlimits;
    private float _zlimits;
    EnvironmentParameters resetParams;

    /// <summary>
    /// Caches components and arena limits, places the player and goal, and spawns the
    /// configured number of enemies and boxes as siblings of the agent.
    /// </summary>
    public override void Initialize()
    {
        resetParams = Academy.Instance.EnvironmentParameters;
        ApplyInnerWallHeight();

        _areaBounds = boundLimits.GetComponent<Collider>().bounds;
        _disstanceToTheGround = GetComponent<Collider>().bounds.extents.y;
        _xlimits = _areaBounds.extents.x;
        _zlimits = _areaBounds.extents.z;
        _rPlayer = GetComponent<Rigidbody>();
        _rGoal = goal.GetComponent<Rigidbody>();

        PlacePlayerAndGoal();

        for (int i = 0; i < _enemyCount; i++)
        {
            GameObject enemyClone = SpawnInArea(enemy, tagsToSpawnOn);
            _enemies.Add(enemyClone);
            _rEnemies.Add(enemyClone.GetComponent<Rigidbody>());
        }

        // One set of boxes per ground type; both sets are cloned from the same prefab.
        for (int i = 0; i < _boxCount; i++)
        {
            _boxes1.Add(SpawnInArea(box1, Ground1Tags));
        }
        for (int i = 0; i < _boxCount; i++)
        {
            _boxes2.Add(SpawnInArea(box1, Ground2Tags));
        }
    }

    /// <summary>
    /// Re-reads the curriculum parameter at the start of every episode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        ApplyInnerWallHeight();
    }

    /// <summary>
    /// Writes the vector observation: agent local position, jumps left, grounded flag,
    /// goal position in the agent's own reference frame, and agent velocity.
    /// </summary>
    /// <param name="sensor">Sensor the observations are appended to.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition);
        // CHANGED: An agent can't learn about a punishment if it can never see why it's happening.
        sensor.AddObservation(_jumpsLeft);
        sensor.AddObservation(_isGrounded);
        // CHANGED: The goal is observed relative to the agent (instead of goal.localPosition)
        // to make the task easier for the agent.
        sensor.AddObservation(this.transform.InverseTransformPoint(goal.transform.position));
        sensor.AddObservation(_rPlayer.velocity);
    }

    /// <summary>
    /// Applies one step of actions: index 0 is forward/back, index 1 is right/left and
    /// index 2 is jump. Also applies the per-step penalty and the termination checks.
    /// </summary>
    /// <param name="vectorAction">Discrete action values, one per branch.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        _isGrounded = Physics.Raycast(transform.position, Vector3.down, _disstanceToTheGround + 0.001f);

        // The clock is advanced once per step inside AddForceToEnemies (called below).
        // Incrementing it here as well made `timer` run at roughly twice real time.

        // CHANGED: AddReward is correct here, we want all punishments to count.
        AddReward(-0.0005f);

        MoveAgentVertical(vectorAction[0]);
        MoveAgentHorizontal(vectorAction[1]);
        // CHANGED: Jump & move work better as concurrent acts; games like Mario allow you to
        // change direction while midair because it's intuitive.
        JumpAgent(vectorAction[2]);

        // At most one termination per step, so a single step cannot end two episodes
        // or stack two terminal penalties.
        if (transform.localPosition.y < 0)
        {
            // Player fell off the arena.
            AddReward(-1f);
            FinishEpisode();
        }
        // CHANGED: 1000 seconds is really long, so the step limit replaces the old timer check.
        else if (StepCount == MaxStep - 1)
        {
            AddReward(-1f);
            FinishEpisode();
        }
        else if (_playerLives < 1)
        {
            // The enemy-hit penalty was already given in OnCollisionEnter.
            FinishEpisode();
        }

        AddForceToEnemies();
    }

    /// <summary>
    /// Penalises and removes a life on contact with an "Enemy"; rewards and ends the
    /// episode on contact with a "Prize".
    /// </summary>
    /// <param name="collision">Collision data supplied by Unity.</param>
    public void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("Enemy"))
        {
            AddReward(-1f);
            _playerLives -= 1;
        }

        if (collision.transform.CompareTag("Prize"))
        {
            // CHANGED: Reward amount
            AddReward(1f);
            FinishEpisode();
            Debug.Log("Prize touched!");
        }
    }

    /// <summary>
    /// Handles the jump action branch: 0 does nothing, 1 jumps if grounded.
    /// A grounded jump request with no jumps left ends the episode with a -1 penalty.
    /// </summary>
    /// <param name="act">Jump branch value.</param>
    public void JumpAgent(float act)
    {
        switch (act)
        {
            // No jump
            case 0:
                break;
            // Jump
            case 1:
                if (!_isGrounded)
                {
                    break;
                }

                if (_jumpsLeft > 0)
                {
                    _isGrounded = false;
                    // CHANGED: the jump cost and the _jumpsLeft decrement are disabled.
                    _rPlayer.AddForce(new Vector3(0, 20f, 0) * Time.fixedDeltaTime * speed, ForceMode.Force);
                }
                else
                {
                    AddReward(-1f);
                    FinishEpisode();
                }
                break;
        }
    }

    // CHANGED: Splitting out horizontal and lateral movement for better agent control (along with jump).

    /// <summary>
    /// Forward/back branch: 0 = none, 1 = push along world +Z, 2 = push along world -Z.
    /// </summary>
    /// <param name="act">Branch value.</param>
    public void MoveAgentVertical(float act)
    {
        switch (act)
        {
            case 1:
                Push(Vector3.forward);
                break;
            case 2:
                Push(Vector3.back);
                break;
        }
    }

    /// <summary>
    /// Right/left branch: 0 = none, 1 = push along world +X, 2 = push along world -X.
    /// </summary>
    /// <param name="act">Branch value.</param>
    public void MoveAgentHorizontal(float act)
    {
        switch (act)
        {
            case 1:
                Push(Vector3.right);
                break;
            case 2:
                Push(Vector3.left);
                break;
        }
    }

    /// <summary>
    /// Advances the episode clock by one fixed step and, whenever the rounded clock is a
    /// multiple of a randomly picked interval, gives every enemy an impulse in a random
    /// direction on the XZ plane.
    /// </summary>
    public void AddForceToEnemies()
    {
        timer += Time.fixedDeltaTime;
        roundtimer = Mathf.RoundToInt(timer);

        int interval = _forceIntervals[Random.Range(0, _forceIntervals.Length)];
        // The condition is the same for every enemy, so test it once instead of per body.
        if (roundtimer % interval != 0)
        {
            return;
        }

        foreach (Rigidbody body in _rEnemies)
        {
            var direction = new Vector3(
                _forceDirections[Random.Range(0, _forceDirections.Length)],
                0,
                _forceDirections[Random.Range(0, _forceDirections.Length)]);
            body.AddForce(direction * Time.fixedDeltaTime * _enemySpeed, ForceMode.Impulse);
        }
    }

    /// <summary>
    /// Searches for a random spot where the only collider within the check radius carries
    /// one of the allowed tags.
    /// </summary>
    /// <param name="transformName">Name of the object being placed; "Player" also records the ground tag it landed on.</param>
    /// <param name="xlimit">Half-width of the sampling range on X.</param>
    /// <param name="zlimit">Half-width of the sampling range on Z.</param>
    /// <param name="customeY">Y value used for every candidate.</param>
    /// <param name="obstacleCheckRadius">Radius of the overlap test around the candidate.</param>
    /// <param name="colsToSpawnOn">Tags a candidate's single overlapping collider may have.</param>
    /// <returns>
    /// A position in the parent area's local space (callers assign it to localPosition), or
    /// (0, customeY, 0) if no suitable spot is found after 500 attempts.
    /// </returns>
    public Vector3 EmptyRandPosition(string transformName, float xlimit, float zlimit, float customeY, float obstacleCheckRadius, string[] colsToSpawnOn)
    {
        const int maxAttempts = 500;

        // Candidates are local to the area the agent lives in, but Physics.OverlapSphere
        // expects a world-space centre. Without the conversion the test probes the wrong
        // place whenever the area is not at the world origin.
        // NOTE(review): assumes the goal shares the agent's parent - confirm in the scene.
        Transform area = transform.parent;

        for (int attempt = 0; attempt < maxAttempts; attempt++)
        {
            Vector3 randPos = new Vector3(Random.Range(-xlimit, xlimit), customeY, Random.Range(-zlimit, zlimit));
            Vector3 worldPos = area != null ? area.TransformPoint(randPos) : randPos;

            Collider[] colliders = Physics.OverlapSphere(worldPos, obstacleCheckRadius);
            if (colliders.Length != 1 || !colsToSpawnOn.Contains(colliders[0].tag))
            {
                continue;
            }

            if (transformName == "Player")
            {
                _playerSpawnedOn = colliders[0].tag;
            }
            return randPos;
        }

        // No free spot found: fall back to the centre of the area.
        return new Vector3(0, customeY, 0);
    }

    /// <summary>
    /// Resets the clock, lives and jump budget, stops all tracked rigidbodies and moves the
    /// player, goal, enemies and boxes to fresh random positions.
    /// </summary>
    public void ResetScene()
    {
        timer = 0;
        _playerLives = 1;
        _jumpsLeft = 150;

        PlacePlayerAndGoal();

        foreach (Rigidbody rEnemy in _rEnemies)
        {
            rEnemy.velocity = Vector3.zero;
            rEnemy.angularVelocity = Vector3.zero;
        }
        foreach (GameObject enemy in _enemies)
        {
            enemy.transform.localPosition = EmptyRandPosition(enemy.transform.name, _xlimits, _zlimits, 1, 2, tagsToSpawnOn);
        }
        foreach (GameObject box in _boxes1)
        {
            box.transform.localPosition = EmptyRandPosition(box.transform.name, _xlimits, _zlimits, 1, 2, Ground1Tags);
        }
        foreach (GameObject box in _boxes2)
        {
            box.transform.localPosition = EmptyRandPosition(box.transform.name, _xlimits, _zlimits, 1, 2, Ground2Tags);
        }
        // CHANGED: roundtimer is not reset here; it is recomputed from timer every step.
    }

    // CHANGED: A curriculum should drive only one parameter, so the wall is no longer scaled;
    // "Inner_wall_height" only moves it along the local y-axis.
    private void ApplyInnerWallHeight()
    {
        innerWall.transform.localPosition = new Vector3(0, resetParams.GetWithDefault("Inner_wall_height", 0.2f), 0f);
    }

    // Stops the player and goal rigidbodies and moves both to random free spots.
    private void PlacePlayerAndGoal()
    {
        _rPlayer.velocity = Vector3.zero;
        _rPlayer.angularVelocity = Vector3.zero;
        _rGoal.velocity = Vector3.zero;
        _rGoal.angularVelocity = Vector3.zero;

        transform.localPosition = EmptyRandPosition(transform.name, _xlimits, _zlimits, 1, 2, tagsToSpawnOn);
        goal.localPosition = EmptyRandPosition(goal.name, _xlimits, _zlimits, 1, 2, tagsToSpawnOn);
    }

    // Clones a prefab as a sibling of the agent and drops it on a random free spot
    // whose ground carries one of the given tags.
    private GameObject SpawnInArea(GameObject prefab, string[] groundTags)
    {
        GameObject clone = Instantiate(prefab, Vector3.zero, Quaternion.identity);
        clone.transform.parent = transform.parent;
        clone.transform.localPosition = EmptyRandPosition(clone.transform.name, _xlimits, _zlimits, 1, 2, groundTags);
        return clone;
    }

    // Applies the standard movement force along the given world-space direction.
    private void Push(Vector3 direction)
    {
        _rPlayer.AddForce(direction * Time.fixedDeltaTime * speed, ForceMode.Force);
    }

    // Ends the current episode and re-randomises the scene for the next one.
    private void FinishEpisode()
    {
        EndEpisode();
        ResetScene();
    }
}