Struggling to teach an AI 3D pong

Hey everyone,

I created a simple game environment that’s just 6 walls, 2 paddles and a ball. The paddles can move along the Y and Z axes, with the X axis being the space between the paddles.

I’ve been trying to use PPO but haven’t had much success. My goal is to get the Agent to have an infinite rally, but even after 20 million steps it struggles to get more than 1 hit per episode.

I’ve tried playing around with the hyperparameters and the reward system, but the results aren’t getting any better. Is there something that I’m missing here?

[code=CSharp]
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

/// <summary>
/// ML-Agents paddle controller: observes the ball and itself, moves in the
/// Y/Z plane via continuous actions, and is rewarded for touching the ball.
/// </summary>
public class Paddle : Agent
{
    // Spawn position the paddle is returned to at the start of every episode.
    private Vector3 initialPosition;
    // Velocity requested by the most recent action.
    private Vector3 movement;
    private Rigidbody paddleRigidbody;

    public BallMovement ball;
    public float speed = 20f;

    // Cache the spawn position and the Rigidbody the actions drive.
    void Start()
    {
        initialPosition = transform.position;
        // BUG FIX: the original called GetComponent() with no type argument,
        // which does not compile; the paddle is moved through its Rigidbody.
        paddleRigidbody = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        ResetPaddles();
    }

    /// <summary>
    /// Adds five normalized Vector3 observations (15 floats total):
    /// direction to the ball, own velocity, own position, ball position,
    /// and ball velocity.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Unit vector pointing from this paddle toward the ball.
        Vector3 toBall = ball.ballRigidbody.transform.position - transform.position;

        sensor.AddObservation(toBall.normalized);
        sensor.AddObservation(paddleRigidbody.velocity.normalized);
        // NOTE(review): normalizing a position discards its magnitude
        // (distance from the origin); dividing by the arena half-extents
        // would preserve more information — confirm against the arena size.
        sensor.AddObservation(transform.position.normalized);

        sensor.AddObservation(ball.transform.position.normalized);
        sensor.AddObservation(ball.ballRigidbody.velocity.normalized);
    }

    /// <summary>
    /// Applies the two continuous actions as Y/Z velocity; X is fixed
    /// (the axis between the paddles).
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        float moveVertical = actions.ContinuousActions[0];
        float moveHorizontal = actions.ContinuousActions[1];

        movement = new Vector3(0f, moveVertical, moveHorizontal);
        paddleRigidbody.velocity = movement * speed;

        // Small living reward while the ball is in play.
        // NOTE(review): this pays the agent every step regardless of behavior,
        // which can swamp the +1 hit / -1 miss signal; consider removing it or
        // shaping on proximity to the ball instead.
        AddReward(0.01f);
    }

    // Keyboard fallback so a human can drive the paddle in heuristic mode.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        float moveVertical = 0f;
        float moveHorizontal = 0f;

        if (Input.GetKey(KeyCode.UpArrow))
        {
            moveVertical = 1f;
        }
        if (Input.GetKey(KeyCode.DownArrow))
        {
            moveVertical = -1f;
        }

        if (Input.GetKey(KeyCode.LeftArrow))
        {
            moveHorizontal = -1f;
        }

        if (Input.GetKey(KeyCode.RightArrow))
        {
            moveHorizontal = 1f;
        }

        actionsOut.ContinuousActions.Array[0] = moveVertical;
        actionsOut.ContinuousActions.Array[1] = moveHorizontal;
    }

    // Snap the paddle back to where it started the scene.
    public void ResetPaddles()
    {
        transform.position = initialPosition;
    }

    // Reward: +1 for touching the ball.
    void OnCollisionEnter(Collision collision)
    {
        // BUG FIX: the original compared tag against a string written with
        // curly ("smart") quotes, which does not compile. CompareTag is also
        // the allocation-free idiom for tag checks.
        if (collision.gameObject.CompareTag("ball"))
        {
            AddReward(1f);
        }
    }
}

[/code]

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Serves the ball in a random direction, bounces it off paddles with an
/// angle based on where it struck, and ends the episode (penalizing the
/// responsible paddle) when it reaches a back wall.
/// </summary>
public class BallMovement : MonoBehaviour
{
    public Rigidbody ballRigidbody;
    public Vector3 direction;
    public Paddle paddle1;
    public Paddle paddle2;

    private Vector3 initialPosition;
    // Paddle that touched the ball most recently; null until the first hit.
    private Paddle lastHit;
    private float speed;
    private int leftOrRight;

    void Start()
    {
        ballRigidbody = GetComponent<Rigidbody>();
        initialPosition = ballRigidbody.transform.position;
        speed = 20f;
        LaunchBall();
    }

    // Serve the ball toward a random side with random Y/Z components.
    // Extracted so ResetGame no longer has to call Start() (which would also
    // redo the GetComponent/initial-position setup).
    private void LaunchBall()
    {
        // Randomly choose left or right along the X axis.
        leftOrRight = Random.Range(0f, 1f) < 0.5f ? -1 : 1;

        // Random unit direction; normalized once here, so no second
        // .normalized is needed when applying it.
        direction = new Vector3(leftOrRight, Random.Range(-1f, 1f), Random.Range(-1f, 1f)).normalized;

        ballRigidbody.velocity = direction * speed;
    }

    void OnCollisionEnter(Collision collision)
    {
        if (collision.gameObject.CompareTag("backWall"))
        {
            // Blame the paddle that last touched the ball for failing to keep
            // the rally going; if nobody has touched it yet, penalize both.
            if (lastHit == paddle1)
            {
                paddle1.AddReward(-1f);
            }
            else if (lastHit == paddle2)
            {
                paddle2.AddReward(-1f);
            }
            else
            {
                paddle1.AddReward(-1f);
                paddle2.AddReward(-1f);
            }
            ResetGame();
        }
        else if (collision.gameObject.CompareTag("paddle"))
        {
            lastHit = collision.gameObject.GetComponent<Paddle>();

            // Offset of the contact point from the paddle center, scaled to
            // roughly [-1, 1] across the paddle's height/width.
            float y = yHit(transform.position, collision.transform.position, collision.collider.bounds.size.y);
            float z = zHit(transform.position, collision.transform.position, collision.collider.bounds.size.z);

            // X sign sends the ball back toward the opposite paddle.
            float x = collision.gameObject.name == "Paddle1" ? -1 : 1;

            direction = new Vector3(x, y, z).normalized;

            // Use the cached Rigidbody instead of re-fetching the component.
            ballRigidbody.velocity = direction * speed;
        }
    }

    // Normalized vertical offset of the ball from the racket center.
    float yHit(Vector3 ballPos, Vector3 racketPos, float racketHeight) {
        return (ballPos.y - racketPos.y) / (racketHeight / 2f);
    }

    // Normalized lateral offset of the ball from the racket center.
    float zHit(Vector3 ballPos, Vector3 racketPos, float racketWidth) {
        return (ballPos.z - racketPos.z) / (racketWidth / 2f);
    }

    // End the episode for both agents and re-serve from the center.
    void ResetGame()
    {
        paddle1.ResetPaddles();
        paddle2.ResetPaddles();
        paddle1.EndEpisode();
        paddle2.EndEpisode();
        ballRigidbody.transform.position = initialPosition;
        // BUG FIX: clear lastHit so the first miss of the new episode is not
        // blamed on whichever paddle ended the previous one.
        lastHit = null;
        LaunchBall();
    }
}
# ML-Agents PPO trainer configuration for the paddle agents.
behaviors:
  My Behavior:
    trainer_type: ppo
    hyperparameters:
      # Experiences per gradient update; buffer_size must be a multiple of it.
      batch_size: 4096
      # 100x batch_size — very large; updates will be infrequent.
      buffer_size: 409600
      learning_rate: 0.0002
      # Entropy regularization strength (exploration pressure).
      beta: 0.003
      # PPO clip range.
      epsilon: 0.15
      # GAE lambda.
      lambd: 0.93
      num_epoch: 6
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 512
      num_layers: 3
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        # Discount factor for future rewards.
        gamma: 0.99 
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 5000000
    # NOTE(review): 100000 is far above the documented typical range
    # (32-2048); with a per-step +0.01 reward this yields huge, noisy
    # returns per trajectory — try something like 128-1000.
    time_horizon: 100000
    summary_freq: 20000

0

I appreciate any help in advance

Thanks!

Here are the parameters I have set in Unity.