I’m using ml-agents 1.0.8 and I made a simple game for a shooter everything works fine when I control it but when I begin training its spawns too many enemies every time it gets eliminated its multiplies
here is the code:
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
public class PlayerController : Agent
{
public float speed = 10.0f;
public GameObject bulletPrefab;
public Transform bulletSpawn;
public int ResetNum=3;
public bool isReset=false;
public int enemiesNum =3;
public GameObject TYameObject;
public GameObject enemy;
[SerializeField] private float platformWidth = 10f;
[SerializeField] private float platformLength = 10f;
[SerializeField] private float spawnHeight = 0.5f;
public float bulletSpeed = 20.0f;
public float moveX = 0.0f;
public float moveZ = 0.0f;
public float rotation = 0.0f;
public override void Initialize()
{
SpawnEnemies();
SpawnEnemies();
SpawnEnemies();
}
public void SpawnEnemies()
{
float minDistanceFromPlayer = 2f; // adjust this value to set the minimum distance from player
Vector3 randomOffset = new Vector3(
Random.Range(-platformWidth / 2f, platformWidth / 2f),
spawnHeight,
Random.Range(-platformLength / 2f, platformLength / 2f));
// Ensure that the randomOffset is at least minDistanceFromPlayer away from the player
while (Vector3.Distance(TYameObject.transform.position + randomOffset, TYameObject.transform.position) < minDistanceFromPlayer)
{
randomOffset = new Vector3(
Random.Range(-platformWidth / 2f, platformWidth / 2f),
spawnHeight,
Random.Range(-platformLength / 2f, platformLength / 2f));
}
Vector3 randomPosition = TYameObject.transform.position + randomOffset;
GameObject currentEnemy = Instantiate(enemy, randomPosition, Quaternion.identity);
}
public override void OnActionReceived(float[] vectorAction)
{
if(Mathf.FloorToInt(vectorAction[0])==1)
{
WalkForward();
}else if(Mathf.FloorToInt(vectorAction[0])==2)
{
WalkBackward();
}
if(Mathf.FloorToInt(vectorAction[1])==1)
{
WalkLeft();
}else if(Mathf.FloorToInt(vectorAction[1])==2)
{
WalkRight();
}
if(Mathf.FloorToInt(vectorAction[2])==1)
{
RotateLeft();
}else if(Mathf.FloorToInt(vectorAction[2])==2)
{
RotateRight();
}
if(Mathf.FloorToInt(vectorAction[3])==1)
{
Shoot();
}
}
public override void OnEpisodeBegin()
{
Reset();
}
public void Reset()
{
}
public override void Heuristic(float[] actionsOut)
{
actionsOut[0]=0;
actionsOut[1]=0;
actionsOut[2]=0;
actionsOut[3]=0;
if (Input.GetKey(KeyCode.LeftArrow))
{
actionsOut[2]=1;
}
if (Input.GetKey(KeyCode.RightArrow))
{
actionsOut[2]=2;
}
if (Input.GetKey(KeyCode.W))
{
actionsOut[0]=1;
}
if (Input.GetKey(KeyCode.S))
{
actionsOut[0]=2;
}
if (Input.GetKey(KeyCode.A))
{
actionsOut[1]=1;
}
if (Input.GetKey(KeyCode.D))
{
actionsOut[1]=2;
}
if (Input.GetMouseButtonDown(0))
{
actionsOut[3]=1;
}
}
public void FixedUpdate() {
RequestDecision();
}
public void Shoot()
{
AddReward(-0.1f);
GameObject bullet = Instantiate(bulletPrefab, bulletSpawn.position, bulletSpawn.rotation);
bullet.GetComponent<Rigidbody>().velocity = transform.forward * bulletSpeed;
Destroy(bullet, 1);
}
private void Update()
{
// Check if there are any game objects with the "MyTag" tag
if (transform.position.y < -2f){
transform. position = new Vector3(TYameObject.transform.position.x, 0.5f, TYameObject.transform.position.z);
isReset=true;
SpawnEnemies();
SpawnEnemies();
SpawnEnemies();
AddReward(-1f);
}
// Rotation
// Movement
Vector3 movement = new Vector3(moveX, 0.0f, moveZ).normalized * speed * Time.deltaTime;
transform.Translate(movement, Space.Self);
moveX = 0.0f;
moveZ = 0.0f;
}
public void RotateLeft()
{
rotation = 0.0f;
rotation -= 1.0f;
transform.Rotate(Vector3.up, rotation * Time.deltaTime * 100.0f);
}
public void RotateRight()
{
rotation = 0.0f;
rotation += 1.0f;
transform.Rotate(Vector3.up, rotation * Time.deltaTime * 100.0f);
}
public void WalkForward()
{
Vector3 direction = transform.forward;
transform.Translate(direction * speed * Time.deltaTime);
if(enemiesNum<=0)
{
AddReward(1f);
enemiesNum=3;
transform. position = new Vector3(TYameObject.transform.position.x, 0.5f, TYameObject.transform.position.z);
isReset=true;
SpawnEnemies();
SpawnEnemies();
SpawnEnemies();
}
}
public void WalkBackward()
{
Vector3 direction = -transform.forward;
transform.Translate(direction * speed * Time.deltaTime);
}
public void WalkLeft()
{
Vector3 direction = -transform.right;
transform.Translate(direction * speed * Time.deltaTime);
}
public void WalkRight()
{
Vector3 direction = transform.right;
transform.Translate(direction * speed * Time.deltaTime);
}
public void OnCollisionEnter(Collision other)
{
if(other.gameObject.tag=="Enemy")
{
transform. position = new Vector3(TYameObject.transform.position.x, 0.5f, TYameObject.transform.position.z);
isReset=true;
SpawnEnemies();
SpawnEnemies();
SpawnEnemies();
AddReward(-1f);
EndEpisode();
}else if (other.gameObject.tag == "Wall")
{
transform. position = new Vector3(TYameObject.transform.position.x, 0.5f, TYameObject.transform.position.z);
isReset=true;
SpawnEnemies();
SpawnEnemies();
SpawnEnemies();
AddReward(-1.0f);
Debug.Log(GetCumulativeReward());
EndEpisode();
}
}
}
enemy ai:
using UnityEngine;
public class EnemyAI : MonoBehaviour
{
public GameObject player;
public float speed = 5f;
public bool isTouch = false;
void Awake()
{
player = FindClosestPlayerWithTag();
}
private GameObject FindClosestPlayerWithTag()
{
GameObject[] players = GameObject.FindGameObjectsWithTag("Player");
GameObject closest = null;
float distance = Mathf.Infinity;
Vector3 position = transform.position;
foreach (GameObject player in players)
{
Vector3 diff = player.transform.position - position;
float curDistance = diff.sqrMagnitude;
if (curDistance < distance)
{
closest = player;
distance = curDistance;
}
}
return closest;
}
public void OnCollisionStay(Collision other)
{
if (other.gameObject == player)
{
isTouch = true;
Debug.Log(player.GetComponent<PlayerController>().GetCumulativeReward());
Destroy(gameObject);
}
}
public void OnCollisionEnter(Collision other)
{
if (other.gameObject.tag == "Bullet")
{
player.GetComponent<PlayerController>().ResetNum--;
player.GetComponent<PlayerController>().AddReward(0.5f);
Debug.Log(player.GetComponent<PlayerController>().GetCumulativeReward());
if(player.GetComponent<PlayerController>().enemiesNum>0)
{
player.GetComponent<PlayerController>().enemiesNum--;
}
Destroy(gameObject);
Destroy(other.gameObject);
}
}
public void OnCollisionExit(Collision other)
{
isTouch = false;
}
void Update()
{
if(player.GetComponent<PlayerController>().isReset)
{
if(player.GetComponent<PlayerController>().ResetNum>0){
Debug.Log("enemyprob prob");
player.GetComponent<PlayerController>().ResetNum--;
Destroy(gameObject);
}else{
player.GetComponent<PlayerController>().isReset=false;
player.GetComponent<PlayerController>().ResetNum=3;
}
}
if (!isTouch)
{
Vector3 direction = player.transform.position - transform.position;
direction.y = 0f;
direction.Normalize();
transform.Translate(direction * speed * Time.deltaTime, Space.World);
}
}
}
the trainer config:
behaviors:
Shooter:
trainer_type: ppo
max_steps: 5.0e7
time_horizon: 64
summary_freq: 10000
hyperparameters:
batch_size: 256
beta: 0.005
buffer_size: 2048
epsilon: 0.2
lambd: 0.95
learning_rate: 0.0004
learning_rate_schedule: linear
num_epoch: 3
network_settings:
vis_encode_type: simple
num_layers: 2
normalize: false
hidden_units: 64
memory:
sequence_length: 64
memory_size: 640
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99