Cannot figure out POCA training setup

Hello, I am creating a FlappyBird project, where the bird plays by itself using AI. I was successful with PPO (Proximal Policy Optimization) training, but could not figure out how to do it for POCA (Parameter Optimization and Continuous Allocation).

In the game, there are 4 exactly same BirdAgent instances in the same environment and I want the episode to end when all of those bird fail. So, when any of the birds fail, BirdAgent script calls BirdFallen() method of the BirdManager script. You can see the scripts attached.

It counts the fallen birds and if all of them fall down, the episode is being ended by agentGroup.EndGroupEpisode() method. However, the episode doesn’t start again after ending once.

agentGroup is of type SimpleMultiAgentGroup class.

I would also appreciate if you could provide me with good documentation & resources, as I think the official documentation is not enough, specifically for SimpleMultiAgentGroup class and its usage.

using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.VisualScripting;
using UnityEditor;
using UnityEngine;

public class BirdAgent : Agent
{
    private float timer = 0f;
    private bool isInitializing = false;

    [SerializeField] private Bird bird;
    [SerializeField] private Rigidbody2D rb2D;

    private void ApplyGroupPerformanceReward()
    {
        if (timer > 5f)
        {
            BirdManager.Instance.AgentGroup.AddGroupReward(1.0f);
        }
    }

    private void Update()
    {
        timer += Time.deltaTime;

        if (timer > 1f)
        {
            AddReward(0.1f);
            timer = 0f;
        }

        ApplyGroupPerformanceReward();
    }

    public bool IsInitializing() => isInitializing;

    public override void OnEpisodeBegin()
    {
        BirdManager.Instance.SetEpisodeEnded();
        BirdManager.Instance.RegisterBird(this);
        BirdManager.Instance.SetActive(bird.gameObject);
        isInitializing = true;
        timer = 0f;
        bird.Reset();
        PipeSpawner.DestroyAllPipes();
        isInitializing = false;
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.position.y);
        sensor.AddObservation(rb2D.velocity.y);

        Pipe nextPipe = PipeSpawner.GetNextPipe();

        if (nextPipe != null)
        {
            sensor.AddObservation(nextPipe.transform.position.x - transform.position.x);

            float gapCenterY = (nextPipe.transform.position.y + nextPipe.LowerPipe.transform.position.y) / 2;
            sensor.AddObservation(gapCenterY - transform.position.y);
        }

        float distanceToGround = transform.position.y;
        sensor.AddObservation(distanceToGround);

        float distanceToCeiling = Camera.main.orthographicSize * 2 - transform.position.y;
        sensor.AddObservation(distanceToCeiling);
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        if (actions.DiscreteActions[0] == 1)
        {
            bird.Jump();

            Pipe nextPipe = PipeSpawner.GetNextPipe();
            if (nextPipe != null)
            {
                if (transform.position.y < nextPipe.CheckpointUpperY - 0.5f && transform.position.y > nextPipe.CheckpointLowerY + 0.5f)
                {
                    AddReward(0.8f);
                }
            }

            if (transform.position.y > 5f)
            {
                AddReward(-0.5f);
                BirdManager.Instance.BirdFailed(this);
            }
        }

        AddReward(0.05f);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = Input.GetKeyDown(KeyCode.Space) ? 1 : 0;
    }

    private void OnTriggerEnter2D(Collider2D collision)
    {
        if (collision.CompareTag("Checkpoint"))
        {
            if (PipeSpawner.Pipes.Count == 0)
            {
                PipeSpawner.Pipes.Dequeue();
            }

            AddReward(1.0f);
        }

        if (collision.CompareTag("Ground"))
        {
            AddReward(-1.0f);
            BirdManager.Instance.BirdFailed(this);
        }

        if (collision.TryGetComponent(out Pipe pipe))
        {
            AddReward(-1.0f);
            BirdManager.Instance.BirdFailed(this);
        }
    }

    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (collision.gameObject.CompareTag("Ground") || collision.gameObject.CompareTag("Pipe"))
        {
            AddReward(-1.0f);
            BirdManager.Instance.BirdFailed(this);
        }
    }
}

using UnityEngine;
using System.Linq;
using System.Collections.Generic;
using Unity.MLAgents;

public class BirdManager : MonoBehaviour
{
    public static BirdManager Instance;
    private List<BirdAgent> allBirds = new List<BirdAgent>();
    private int birdsFallen = 0;

    private SimpleMultiAgentGroup agentGroup;
    public SimpleMultiAgentGroup AgentGroup { get => agentGroup; }

    private void Awake()
    {
        if (Instance == null)
        {
            Instance = this;
        }
        else
        {
            Destroy(gameObject);
        }

        agentGroup = new SimpleMultiAgentGroup();
    }

    public void RegisterBird(BirdAgent bird)
    {
        if (!allBirds.Contains(bird))
        {
            allBirds.Add(bird);
        }
        agentGroup.RegisterAgent(bird);
    }

    private bool episodeEnded = false;

    public void BirdFailed(BirdAgent fallenBird)
    {
        birdsFallen++;
        fallenBird.gameObject.SetActive(false);

        if (birdsFallen >= allBirds.Count && !episodeEnded)
        {
            episodeEnded = true;
            agentGroup.EndGroupEpisode();
            birdsFallen = 0;
        }
    }

    public void SetActive(GameObject obj)
    {
        obj.SetActive(true);
    }

    private int createdBird = 0;

    public void SetEpisodeEnded()
    {
        episodeEnded = false;
    }
}