I have the following code, plus a cube and a floor that both have Rigidbody components. As far as I understand (possibly incorrectly), this code should allow me to train the cube so that it moves around without falling off the platform.
But I have two problems:
- Heuristics.
When I set the behavior type to Heuristic, I would expect to be able to test the agent by moving the cube with the arrow keys, but this doesn’t work.
- Inference.
I was hoping the engine would generate random action values and fail from time to time by falling off the platform, but this doesn’t happen either.
Note: The "Heuristic" and "action Received" logs never appear, although new episodes do start when I increase Max Step.
Any suggestion? What am I missing?
Thanks!!
Code:
using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// ML-Agents agent that slides a cube around on the X/Z plane and is penalised
/// when it leaves the platform.
/// </summary>
/// <remarks>
/// NOTE(review): <see cref="OnActionReceived"/> and <see cref="Heuristic"/> are only
/// invoked after a decision has been requested. If their logs never appear, the
/// GameObject is most likely missing a <c>DecisionRequester</c> component (or a manual
/// call to <c>RequestDecision()</c>) - confirm in the scene setup.
/// The Behavior Parameters must declare 2 continuous actions and a vector observation
/// size of 5 to match this script.
/// </remarks>
public class PlayerAgent : Agent
{
    // Height band (exclusive) in which the cube counts as still being on the platform.
    private const float MinAliveHeight = 0f;
    private const float MaxAliveHeight = 10f;

    // Units per second moved for an action value of 1.
    private const float MoveSpeed = 1f;

    // Reward given on every action while the cube is still on the platform.
    private const float AliveReward = 0.1f;

    // Reward given once when the cube leaves the platform.
    private const float FallPenalty = -1f;

    Rigidbody rBody;
    Vector3 startingPosition;

    /// <summary>
    /// Caches the Rigidbody and remembers the spawn position used to reset each episode.
    /// </summary>
    public override void Initialize()
    {
        Debug.Log("Initialize");
        rBody = GetComponent<Rigidbody>();
        startingPosition = transform.position;
    }

    /// <summary>
    /// Applies the two continuous actions as X/Z movement and assigns the reward.
    /// The episode ends only when the cube has left the allowed height band.
    /// </summary>
    /// <param name="actions">Action buffers; continuous index 0 is X, index 1 is Z.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        Debug.Log("action Received");
        float moveX = actions.ContinuousActions[0];
        float moveZ = actions.ContinuousActions[1];
        transform.position += new Vector3(moveX, 0, moveZ) * Time.deltaTime * MoveSpeed;

        float height = transform.position.y;
        bool onPlatform = height > MinAliveHeight && height < MaxAliveHeight;
        if (onPlatform)
        {
            // Still alive: reward survival but keep the episode running. Ending the
            // episode here would reset the cube after every single action, so it
            // could never travel far enough to fall.
            Debug.Log("1");
            SetReward(AliveReward);
            return;
        }

        // The cube fell off (or left the height band): penalise and restart.
        Debug.Log("2");
        SetReward(FallPenalty);
        EndEpisode();
    }

    /// <summary>
    /// Manual control for the Heuristic behavior type: maps the "Horizontal" and
    /// "Vertical" input axes onto the two continuous actions.
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill with the player's input.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        Debug.Log("Heuristic");
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        continuousActions[0] = Input.GetAxisRaw("Horizontal");
        continuousActions[1] = Input.GetAxisRaw("Vertical");
    }

    /// <summary>
    /// Resets the cube at the start of every episode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        Debug.Log("OnEpisodeBegin");
        ResetScene();
    }

    /// <summary>
    /// Puts the cube back at its spawn pose and clears any leftover physics motion.
    /// </summary>
    void ResetScene()
    {
        Debug.Log("ResetScene");

        // Quaternion.identity is the "no rotation" value; an all-zero quaternion is
        // not a valid rotation.
        transform.rotation = Quaternion.identity;
        transform.position = startingPosition;

        // Without this the cube would keep the falling/tumbling velocity it had when
        // the previous episode ended.
        if (rBody != null)
        {
            rBody.velocity = Vector3.zero;
            rBody.angularVelocity = Vector3.zero;
        }
    }

    /// <summary>
    /// Adds 5 float observations: the world position (3 values) followed by the z and
    /// x components of the rotation quaternion.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        Debug.Log("CollectObservations");
        sensor.AddObservation(transform.position);
        sensor.AddObservation(transform.rotation.z);
        sensor.AddObservation(transform.rotation.x);
    }
}