ML-Agents and NavMesh

Hi, I am trying to create an agent that learns to play a simple RTS game that I created. The problem is that all the ML-Agents examples I have looked at use a Rigidbody component for handling movement and collisions, whereas my implementation uses a NavMeshAgent for pathfinding and moving across the map, and a capsule collider for collision detection.

The four scripts I've added below show what I am trying to turn into ML-Agents. The first script is my unit script, where I set up all the behaviours (as states) that a unit can perform. The second, third and fourth scripts show how a player controls the units and how they can create new units.

unitScript ::

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.AI;
using UnityEngine.Events;
/// <summary>
/// States of the per-unit state machine driven by <c>Unit.Update</c>.
/// The numeric values are spelled out because Unity serializes enum fields by their
/// integer value; reordering members would silently change saved scenes/prefabs.
/// </summary>
public enum UnitState
{
    /// <summary>No active order; entering this state stops the NavMeshAgent and clears its path.</summary>
    Idle = 0,

    /// <summary>Walking to a commanded ground position.</summary>
    Move = 1,

    /// <summary>Walking to a resource; switches to <see cref="Gather"/> on arrival.</summary>
    MoveToResource = 2,

    /// <summary>Harvesting the current resource source on a timer.</summary>
    Gather = 3,

    /// <summary>Chasing the current enemy target until it is within attack distance.</summary>
    MoveToEnemy = 4,

    /// <summary>In range of the current enemy target and attacking it on a timer.</summary>
    Attack = 5
}
/// <summary>
/// A single RTS unit. Movement is handled by a <see cref="NavMeshAgent"/>; behaviour is a small
/// state machine over <see cref="UnitState"/> advanced once per frame in <c>Update</c>.
/// Orders arrive through the public command methods <see cref="MoveToPosition"/>,
/// <see cref="GatherResource"/> and <see cref="AttackUnit"/>, which are also what any external
/// controller (player input, scripted AI, or a learning agent) would call.
/// </summary>
public class Unit : MonoBehaviour
{
    [Header("Stats")]
    public UnitState state;
    public int curHp;
    public int maxHp;
    public int minAttackDamage;
    public int maxAttackDamage;           // inclusive upper bound of a hit
    public float attackRate;              // seconds between attacks
    private float lastAttackTime;
    public float attackDistance;
    public float pathUpdateRate = 1.0f;   // seconds between re-paths while chasing an enemy
    private float lastPathUpdateTime;
    public int gatherAmount;
    public float gatherRate;              // seconds between gathers
    private float lastGatherTime;
    public ResourceSource curResourceSource;
    private Unit curEnemyTarget;

    [Tooltip("Slack (world units) added to the NavMeshAgent stopping distance when deciding a move has finished.")]
    public float arrivalThreshold = 0.1f;

    [Header("Components")]
    public GameObject selectionVisual;
    private NavMeshAgent navAgent;
    public UnitHealthBar healthBar;
    public Player player;

    // Set by Die(). Destroy() only takes effect at the end of the frame, so other attackers can
    // still call TakeDamage after HP reaches 0; this stops Die() from running more than once.
    private bool isDead;

    // events
    [System.Serializable]
    public class StateChangeEvent : UnityEvent<UnitState> { }
    public StateChangeEvent onStateChange;

    void Awake ()
    {
        // Fetched in Awake rather than Start: Awake runs inside Instantiate, so commands issued
        // immediately after spawning (e.g. by Player.onUnitCreated listeners) find a valid agent.
        navAgent = GetComponent<NavMeshAgent>();
    }

    void Start ()
    {
        // NOTE(review): this also cancels any order given between Instantiate and the first frame.
        SetState(UnitState.Idle);
    }

    // Switches state, notifies listeners, and halts the agent when going idle.
    void SetState (UnitState toState)
    {
        state = toState;

        if (onStateChange != null)
            onStateChange.Invoke(state);

        if (toState == UnitState.Idle)
        {
            navAgent.isStopped = true;
            navAgent.ResetPath();
        }
    }

    void Update ()
    {
        switch (state)
        {
            case UnitState.Move:           MoveUpdate();           break;
            case UnitState.MoveToResource: MoveToResourceUpdate(); break;
            case UnitState.Gather:         GatherUpdate();         break;
            case UnitState.MoveToEnemy:    MoveToEnemyUpdate();    break;
            case UnitState.Attack:         AttackUpdate();         break;
        }
    }

    // True once the agent's path has been computed and the unit is within stopping distance of
    // its end. An exact "distance to destination == 0" test almost never holds for a NavMeshAgent
    // (float error, agent base offset, stopping distance, destinations off the mesh), which would
    // leave the unit stuck in a Move state.
    bool HasReachedDestination ()
    {
        if (navAgent.pathPending)
            return false;

        return navAgent.remainingDistance <= navAgent.stoppingDistance + arrivalThreshold;
    }

    // called every frame the 'Move' state is active
    void MoveUpdate ()
    {
        if (HasReachedDestination())
            SetState(UnitState.Idle);
    }

    // called every frame the 'MoveToResource' state is active
    void MoveToResourceUpdate ()
    {
        if (curResourceSource == null)
        {
            SetState(UnitState.Idle);
            return;
        }

        if (HasReachedDestination())
            SetState(UnitState.Gather);
    }

    // called every frame the 'Gather' state is active
    void GatherUpdate ()
    {
        if (curResourceSource == null)
        {
            SetState(UnitState.Idle);
            return;
        }

        LookAt(curResourceSource.transform.position);

        if (Time.time - lastGatherTime > gatherRate)
        {
            lastGatherTime = Time.time;
            curResourceSource.GatherResource(gatherAmount, player);
        }
    }

    // called every frame the 'MoveToEnemy' state is active
    void MoveToEnemyUpdate ()
    {
        // if our target is dead, go idle
        if (curEnemyTarget == null)
        {
            SetState(UnitState.Idle);
            return;
        }

        // Re-path on the timer, and also immediately whenever the agent is stopped (it was halted
        // by the Attack state) so the unit never stands still waiting for the timer to expire.
        if (navAgent.isStopped || Time.time - lastPathUpdateTime > pathUpdateRate)
        {
            lastPathUpdateTime = Time.time;
            navAgent.isStopped = false;
            navAgent.SetDestination(curEnemyTarget.transform.position);
        }

        if (Vector3.Distance(transform.position, curEnemyTarget.transform.position) <= attackDistance)
            SetState(UnitState.Attack);
    }

    // called every frame the 'Attack' state is active
    void AttackUpdate ()
    {
        // if our target is dead, go idle
        if (curEnemyTarget == null)
        {
            SetState(UnitState.Idle);
            return;
        }

        // Range is checked before attacking so a target that has just moved away is not hit.
        if (Vector3.Distance(transform.position, curEnemyTarget.transform.position) > attackDistance)
        {
            SetState(UnitState.MoveToEnemy);
            return;
        }

        // if we're still moving, stop
        if (!navAgent.isStopped)
            navAgent.isStopped = true;

        LookAt(curEnemyTarget.transform.position);

        // attack every 'attackRate' seconds; int Random.Range excludes its max, hence the + 1
        if (Time.time - lastAttackTime > attackRate)
        {
            lastAttackTime = Time.time;
            curEnemyTarget.TakeDamage(Random.Range(minAttackDamage, maxAttackDamage + 1));
        }
    }

    /// <summary>Applies <paramref name="damage"/> and kills the unit when HP drops to 0 or below.</summary>
    /// <param name="damage">HP to subtract.</param>
    public void TakeDamage (int damage)
    {
        // already dying this frame; further hits must not re-run the death logic
        if (isDead)
            return;

        curHp -= damage;

        if (healthBar != null)
            healthBar.UpdateHealthBar(curHp, maxHp);

        if (curHp <= 0)
            Die();
    }

    // Removes the unit from its owner, lets the game manager re-check win/loss, and destroys it.
    // Guarded so it runs at most once per unit.
    void Die ()
    {
        if (isDead)
            return;
        isDead = true;

        if (player != null)
            player.units.Remove(this);

        GameManager.instance.UnitDeathCheck();
        Destroy(gameObject);
    }

    /// <summary>Orders the unit to walk to <paramref name="pos"/>.</summary>
    public void MoveToPosition (Vector3 pos)
    {
        SetState(UnitState.Move);
        navAgent.isStopped = false;
        navAgent.SetDestination(pos);
    }

    /// <summary>Orders the unit to walk to <paramref name="pos"/> and then gather from <paramref name="resource"/>.</summary>
    public void GatherResource (ResourceSource resource, Vector3 pos)
    {
        curResourceSource = resource;
        SetState(UnitState.MoveToResource);
        navAgent.isStopped = false;
        navAgent.SetDestination(pos);
    }

    /// <summary>Orders the unit to chase and attack <paramref name="target"/>.</summary>
    public void AttackUnit (Unit target)
    {
        curEnemyTarget = target;
        // force MoveToEnemyUpdate to path toward the new target on its first tick
        lastPathUpdateTime = float.NegativeInfinity;
        SetState(UnitState.MoveToEnemy);
    }

    /// <summary>Shows or hides the selection ring around the unit's feet.</summary>
    public void ToggleSelectionVisual (bool selected)
    {
        if (selectionVisual != null)
            selectionVisual.SetActive(selected);
    }

    // Rotates around Y only, to face pos on the ground plane. If pos is at (or directly
    // above/below) our position there is no meaningful heading, so rotation is left unchanged.
    void LookAt (Vector3 pos)
    {
        Vector3 dir = pos - transform.position;
        dir.y = 0f;

        if (dir.sqrMagnitude < 0.0001f)
            return;

        float angle = Mathf.Atan2(dir.x, dir.z) * Mathf.Rad2Deg;
        transform.rotation = Quaternion.Euler(0, angle, 0);
    }
}

playerScript ::

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Events;
/// <summary>
/// One side in the match. Owns a list of units and a food stockpile, and spawns new units:
/// single gatherers, or batches of soldiers and commanders, only when it can afford the whole purchase.
/// </summary>
public class Player : MonoBehaviour
{
    public bool isMe;

    [Header("Units")]
    public List<Unit> units = new List<Unit>();

    [Header("Resources")]
    public int food;

    [Header("Components")]
    #region unit components
    [Header("Gatherer Unit")]
    public GameObject unitPrefab;
    public Transform unitSpawnPos;
    [Header("Soldier Unit")]
    public GameObject unit2Prefab;
    public Transform unit2SpawnPos;
    public GameObject[] soliderArray;     // soldiers spawned by the most recent successful CreateNewUnit2()
    [Header("Commander Unit")]
    public GameObject unit3Prefab;
    public Transform unit3SpawnPos;
    public GameObject[] commanderArray;   // commanders spawned by the most recent successful CreateNewUnit3()
    #endregion

    // events
    [System.Serializable]
    public class UnitCreatedEvent : UnityEvent<Unit> { }
    public UnitCreatedEvent onUnitCreated;

    #region unit costs
    // Costs are per unit. Soldiers and commanders spawn in batches, so one purchase costs
    // cost * batch size (4 soldiers = 100 food, 5 commanders = 200 food).
    public readonly int unitCost = 50;
    public readonly int unit2Cost = 25;
    public readonly int unit3Cost = 40;
    public readonly int soldierBatchSize = 4;
    public readonly int commanderBatchSize = 5;
    #endregion

    public static Player me;

    void Awake ()
    {
        if (isMe)
            me = this;
    }

    void Start ()
    {
        if (isMe)
        {
            GameUI.instance.UpdateUnitCountText(units.Count);
            GameUI.instance.UpdateFoodText(food);
            // NOTE(review): soldier/commander texts are given the total unit count and are not
            // refreshed when soldiers/commanders spawn; confirm this is intended.
            GameUI.instance.UpdateSoldierCountText(units.Count);
            GameUI.instance.UpdateCommanderCountText(units.Count);
            CameraController.instance.FocusOnPosition(unitSpawnPos.position);
        }

        // grant exactly the cost of the free starting gatherer, then buy it
        food += unitCost;
        CreateNewUnit();
    }

    /// <summary>Adds gathered resources to this player's stockpile.</summary>
    /// <param name="resourceType">Kind of resource gathered; only Food is handled.</param>
    /// <param name="amount">Amount to add.</param>
    public void GainResource (ResourceType resourceType, int amount)
    {
        switch (resourceType)
        {
            case ResourceType.Food:
                food += amount;
                if (isMe)
                    GameUI.instance.UpdateFoodText(food);
                break;
        }
    }

    // debug: spawn a soldier batch when N is pressed (remove the comment markers to enable)
    /*
    void Update ()
    {
        if (Input.GetKeyDown(KeyCode.N))
            CreateNewUnit2();
    }
    */

    #region create units
    /// <summary>Spawns one gatherer if the player can afford <see cref="unitCost"/>.</summary>
    public void CreateNewUnit ()
    {
        SpawnUnits(unitPrefab, unitSpawnPos, unitCost, 1);
    }

    /// <summary>Spawns a batch of <see cref="soldierBatchSize"/> soldiers if the whole batch is affordable.</summary>
    public void CreateNewUnit2 ()
    {
        GameObject[] spawned = SpawnUnits(unit2Prefab, unit2SpawnPos, unit2Cost, soldierBatchSize);
        if (spawned != null)
            soliderArray = spawned;
    }

    /// <summary>Spawns a batch of <see cref="commanderBatchSize"/> commanders if the whole batch is affordable.</summary>
    public void CreateNewUnit3 ()
    {
        GameObject[] spawned = SpawnUnits(unit3Prefab, unit3SpawnPos, unit3Cost, commanderBatchSize);
        if (spawned != null)
            commanderArray = spawned;
    }

    // Spawns `count` copies of `prefab` at `spawnPoint`, parented to this player and charged
    // `costEach` food apiece. The purchase is all-or-nothing: if the player cannot afford the
    // whole batch nothing is spent and null is returned, so food never goes negative.
    GameObject[] SpawnUnits (GameObject prefab, Transform spawnPoint, int costEach, int count)
    {
        if (count <= 0 || food < costEach * count)
            return null;

        GameObject[] spawned = new GameObject[count];

        for (int i = 0; i < count; i++)
        {
            GameObject unitObj = Instantiate(prefab, spawnPoint.position, Quaternion.identity, transform);
            Unit unit = unitObj.GetComponent<Unit>();

            spawned[i] = unitObj;
            units.Add(unit);
            unit.player = this;
            food -= costEach;

            if (onUnitCreated != null)
                onUnitCreated.Invoke(unit);
        }

        if (isMe)
        {
            GameUI.instance.UpdateUnitCountText(units.Count);
            GameUI.instance.UpdateFoodText(food);
        }

        return spawned;
    }
    #endregion

    /// <summary>Returns true if <paramref name="unit"/> belongs to this player.</summary>
    public bool IsMyUnit (Unit unit)
    {
        return units.Contains(unit);
    }
}

UnitCommanderScript ::

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Translates right-clicks into orders for the currently selected units: move on ground,
/// gather on a resource, attack on an enemy unit. Requires a <c>UnitSelection</c> on the same GameObject.
/// </summary>
public class UnitCommander : MonoBehaviour
{
    public GameObject selectionMarkerPrefab;
    public LayerMask layerMask;

    [Tooltip("Maximum length of the command raycast from the camera.")]
    public float raycastDistance = 100f;

    [Tooltip("Spacing between units in a move formation.")]
    public float formationGap = 2f;

    // components
    private UnitSelection unitSelection;
    private Camera cam;

    void Awake ()
    {
        // get the components
        unitSelection = GetComponent<UnitSelection>();
        cam = Camera.main;
    }

    void Update ()
    {
        // only act on a right-click while something is selected
        if (!Input.GetMouseButtonDown(1) || !unitSelection.HasUnitsSelected())
            return;

        Ray ray = cam.ScreenPointToRay(Input.mousePosition);
        RaycastHit hit;
        if (!Physics.Raycast(ray, out hit, raycastDistance, layerMask))
            return;

        // Prune destroyed units BEFORE taking the snapshot; otherwise the cached array can
        // still contain dead units and ordering them throws.
        unitSelection.RemoveNullUnitsFromSelection();
        Unit[] selectedUnits = unitSelection.GetSelectedUnits();
        if (selectedUnits == null || selectedUnits.Length == 0)
            return;

        Collider clicked = hit.collider;

        if (clicked.CompareTag("Ground"))
        {
            UnitsMoveToPosition(hit.point, selectedUnits);
            CreateSelectionMarker(hit.point, false);
        }
        else if (clicked.CompareTag("Resource"))
        {
            ResourceSource resource = clicked.GetComponent<ResourceSource>();
            if (resource == null)
                return;

            UnitsGatherResource(resource, selectedUnits);
            CreateSelectionMarker(clicked.transform.position, true);
        }
        else if (clicked.CompareTag("Unit"))
        {
            Unit enemy = clicked.GetComponent<Unit>();
            if (enemy != null && !Player.me.IsMyUnit(enemy))
            {
                UnitsAttackEnemy(enemy, selectedUnits);
                CreateSelectionMarker(enemy.transform.position, false);
            }
        }
    }

    // Spreads the units over a grid formation centred on movePos.
    void UnitsMoveToPosition (Vector3 movePos, Unit[] units)
    {
        Vector3[] destinations = UnitMover.GetUnitGroupDestinations(movePos, units.Length, formationGap);
        for (int x = 0; x < units.Length; x++)
            units[x].MoveToPosition(destinations[x]);
    }

    // Sends the units to gather: one unit gets a random spot around the resource, a group is
    // spread evenly around it.
    void UnitsGatherResource (ResourceSource resource, Unit[] units)
    {
        Vector3 resourcePos = resource.transform.position;

        if (units.Length == 1)
        {
            units[0].GatherResource(resource, UnitMover.GetUnitDestinationAroundResource(resourcePos));
            return;
        }

        Vector3[] destinations = UnitMover.GetUnitGroupDestinationsAroundResource(resourcePos, units.Length);
        for (int x = 0; x < units.Length; x++)
            units[x].GatherResource(resource, destinations[x]);
    }

    // Orders every unit to attack the same target.
    void UnitsAttackEnemy (Unit target, Unit[] units)
    {
        for (int x = 0; x < units.Length; x++)
            units[x].AttackUnit(target);
    }

    // Spawns the click-feedback marker just above the ground at pos; 'large' triples its scale.
    void CreateSelectionMarker (Vector3 pos, bool large)
    {
        GameObject marker = Instantiate(selectionMarkerPrefab, new Vector3(pos.x, 0.01f, pos.z), Quaternion.identity);
        if (large)
            marker.transform.localScale = Vector3.one * 3;
    }
}

unitMoverScript ::

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Static helpers that compute destination points for groups of units.
/// (Kept as a MonoBehaviour so existing scene references to the component stay valid.)
/// </summary>
public class UnitMover : MonoBehaviour
{
    /// <summary>
    /// Lays <paramref name="numUnits"/> destinations out on a near-square grid centred on
    /// <paramref name="moveToPos"/>. Units fill a line along Z before starting the next line along X.
    /// The grid is centred for its full size, so an incomplete last line sits off-centre.
    /// </summary>
    /// <param name="moveToPos">Centre of the formation.</param>
    /// <param name="numUnits">Number of destinations; 0 or less yields an empty array.</param>
    /// <param name="unitGap">Spacing between neighbouring destinations.</param>
    public static Vector3[] GetUnitGroupDestinations (Vector3 moveToPos, int numUnits, float unitGap)
    {
        if (numUnits <= 0)
            return new Vector3[0];

        Vector3[] destinations = new Vector3[numUnits];

        // units per line (along Z) and number of lines (along X); perLine >= 1 for numUnits >= 1
        int perLine = Mathf.RoundToInt(Mathf.Sqrt(numUnits));
        int lineCount = Mathf.CeilToInt((float)numUnits / (float)perLine);

        float xExtent = (lineCount - 1) * unitGap;
        float zExtent = (perLine - 1) * unitGap;
        Vector3 centreOffset = new Vector3(xExtent / 2, 0, zExtent / 2);

        for (int i = 0; i < numUnits; i++)
        {
            int line = i / perLine;
            int slot = i % perLine;
            destinations[i] = moveToPos + (new Vector3(line, 0, slot) * unitGap) - centreOffset;
        }

        return destinations;
    }

    /// <summary>
    /// Returns <paramref name="unitsNum"/> points evenly spaced on a circle around a resource.
    /// </summary>
    /// <param name="resourcePos">Centre of the circle.</param>
    /// <param name="unitsNum">Number of points; 0 or less yields an empty array.</param>
    /// <param name="radius">Circle radius (defaults to 1, the previous fixed distance).</param>
    public static Vector3[] GetUnitGroupDestinationsAroundResource (Vector3 resourcePos, int unitsNum, float radius = 1.0f)
    {
        if (unitsNum <= 0)
            return new Vector3[0];

        Vector3[] destinations = new Vector3[unitsNum];
        float angleStep = 360.0f / (float)unitsNum;

        for (int i = 0; i < unitsNum; i++)
            destinations[i] = resourcePos + DirectionFromAngle(angleStep * i) * radius;

        return destinations;
    }

    /// <summary>
    /// Returns a point at a random whole-degree angle on a circle around a resource.
    /// </summary>
    /// <param name="resourcePos">Centre of the circle.</param>
    /// <param name="radius">Circle radius (defaults to 1, the previous fixed distance).</param>
    public static Vector3 GetUnitDestinationAroundResource (Vector3 resourcePos, float radius = 1.0f)
    {
        // int overload: 0..359 inclusive
        float angle = Random.Range(0, 360);
        return resourcePos + DirectionFromAngle(angle) * radius;
    }

    // Unit vector on the XZ plane; 0 degrees points along +Z and angles increase toward +X.
    static Vector3 DirectionFromAngle (float degrees)
    {
        float rad = degrees * Mathf.Deg2Rad;
        return new Vector3(Mathf.Sin(rad), 0, Mathf.Cos(rad));
    }
}

So any help on how to make these scripts controllable by an ML-Agent would be greatly appreciated. I'm not looking for someone to do the work for me, just a nudge in the right direction with a few examples. I have thought about using a Ray Perception Sensor for the unit script, but that seems to depend on using a Rigidbody component. For the player script, I thought about using a camera sensor to recognise how many resources the player has and then create a new unit. In short: what do I put in the key methods used in ML-Agents?

// ML-Agents Agent overrides: these belong inside a class deriving from the ML-Agents Agent type.
// NOTE(review): the lifecycle notes below follow the ML-Agents "Agent" documentation for the
// float[]-based action API these signatures use; confirm against your installed package version.
// NOTE(review): the actions here can be high-level commands (the Unit/Player methods above), so
// the NavMeshAgent can keep doing the movement and no Rigidbody is needed for the agent itself.
// A RayPerceptionSensor casts physics rays against colliders, so the capsule colliders are
// detectable. Verify the detectable tags in the sensor component.

/// <summary>
/// One-time setup when the agent is enabled.
/// Suggested: cache the controlled Unit (or Player) and anything the agent reads each step.
/// </summary>
public override void Initialize() { }

/// <summary>
/// Start of every training episode.
/// Suggested: reset unit positions, curHp, food and resource sources so episodes are comparable.
/// </summary>
public override void OnEpisodeBegin() { }

/// <summary>
/// Receives the policy's actions for one decision step.
/// Suggested mapping onto the existing command API:
/// - a discrete branch choosing idle / move / gather / attack / spawn;
/// - further values selecting the target, passed to Unit.MoveToPosition, Unit.GatherResource,
///   Unit.AttackUnit, or Player.CreateNewUnit / CreateNewUnit2 / CreateNewUnit3.
/// Rewards (e.g. for food gathered or enemy units destroyed) are commonly added here.
/// Observations such as Player.food and unit HP would go in a CollectObservations override.
/// </summary>
public override void OnActionReceived(float[] vectorAction) { }

/// <summary>
/// Fills <paramref name="actionsOut"/> from manual input when the behaviour runs in heuristic mode,
/// which is useful for testing the action mapping by hand before training.
/// </summary>
public override void Heuristic(float[] actionsOut) { }

Any help is greatly appreciated.

Hey @MurdoMacIver , check out our example environments in the repo for some examples on what to put in those methods. From what I can gather, navigation tasks like Hallway, FoodCollector and Pyramids might be most similar to your task.