using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;

/// <summary>
/// Episode environment: moves the player back to the local origin and spawns one target
/// potion and one poison potion as children of this transform, then hands their transforms
/// to the player's <see cref="PlayerAgent"/>.
/// </summary>
public class Env : MonoBehaviour
{
    public GameObject playerGo;
    public GameObject targetGo;
    public GameObject poisonGo;

    public TMP_Text cumulativeRewardText;

    /// <summary>Potion local X/Z coordinates are drawn uniformly from [-spawnExtent, spawnExtent].</summary>
    public float spawnExtent = 4f;

    /// <summary>Minimum local-space distance kept between player/target, player/poison and target/poison.</summary>
    public float minSeparation = 2f;

    /// <summary>Upper bound on re-rolls per potion so an infeasible extent/separation setting cannot hang the game.</summary>
    public int maxPlacementAttempts = 100;

    private const float PotionHeight = 0.1f;

    // Potions spawned by PlacePotions. Allocated here rather than in Start so that
    // ResetEnv is safe even if something calls it before Start has run.
    private readonly List<GameObject> potions = new List<GameObject>();

    // Cached lookup of the PlayerAgent on playerGo (avoids GetComponent every frame).
    private PlayerAgent playerAgent;

    private PlayerAgent PlayerAgentComponent
    {
        get
        {
            // Explicit == null (not ??) so Unity's destroyed-object null semantics apply.
            if (this.playerAgent == null)
            {
                this.playerAgent = this.playerGo.GetComponent<PlayerAgent>();
            }
            return this.playerAgent;
        }
    }

    private void Start()
    {
        this.ResetEnv();
    }

    /// <summary>Destroys the current potions, re-centres the player and spawns a fresh pair.</summary>
    public void ResetEnv()
    {
        this.RemovePotions();
        this.PlacePlayer();
        this.PlacePotions();
    }

    private void Update()
    {
        // The reward label is optional UI; skip instead of throwing every frame when unassigned.
        if (this.cumulativeRewardText == null)
        {
            return;
        }

        this.cumulativeRewardText.text = this.PlayerAgentComponent.GetCumulativeReward().ToString("0.00");
    }

    /// <summary>
    /// Spawns one target and one poison under this transform. The target is kept at least
    /// <see cref="minSeparation"/> from the player; the poison at least that far from both
    /// the player and the target. Registers both with the player's agent.
    /// </summary>
    public void PlacePotions()
    {
        var target = Instantiate(this.targetGo, this.transform);
        var poison = Instantiate(this.poisonGo, this.transform);

        // NOTE(review): compares the player's localPosition with the potions' localPosition,
        // which assumes the player is also a direct child of this transform — confirm in the scene.
        Vector3 playerPos = this.playerGo.transform.localPosition;

        // Keep the target off the player too; otherwise it can spawn inside the agent's
        // pickup radius and the episode ends with an unearned reward.
        target.transform.localPosition = this.SampleSpawnPosition(playerPos);
        Vector3 targetPos = target.transform.localPosition;
        poison.transform.localPosition = this.SampleSpawnPosition(targetPos, playerPos);

        this.potions.Add(target);
        this.potions.Add(poison);

        PlayerAgent agent = this.PlayerAgentComponent;
        agent.target = target.transform;
        agent.poison = poison.transform;
    }

    /// <summary>Destroys every spawned potion and forgets them.</summary>
    public void RemovePotions()
    {
        foreach (GameObject potion in this.potions)
        {
            if (potion != null)
            {
                Destroy(potion);
            }
        }

        // Without this the list grows every episode and keeps re-destroying dead objects.
        this.potions.Clear();
    }

    /// <summary>Misspelled original name, kept so existing callers keep compiling. Use <see cref="RemovePotions"/>.</summary>
    public void RemovePosions()
    {
        this.RemovePotions();
    }

    /// <summary>Puts the player back at the local origin.</summary>
    public void PlacePlayer()
    {
        this.playerGo.transform.localPosition = new Vector3(0, 0.1f, 0);
    }

    // Rejection-samples a spawn point at least minSeparation from every point in `avoid`.
    // Gives up after maxPlacementAttempts and returns the last candidate.
    private Vector3 SampleSpawnPosition(params Vector3[] avoid)
    {
        Vector3 candidate = this.RandomSpawnPosition();
        for (int attempt = 1; attempt < this.maxPlacementAttempts && !this.IsClearOf(candidate, avoid); attempt++)
        {
            candidate = this.RandomSpawnPosition();
        }
        return candidate;
    }

    // True when `candidate` is at least minSeparation from every point in `avoid`.
    private bool IsClearOf(Vector3 candidate, Vector3[] avoid)
    {
        foreach (Vector3 point in avoid)
        {
            if (Vector3.Distance(candidate, point) < this.minSeparation)
            {
                return false;
            }
        }
        return true;
    }

    // Uniform point in the spawn square. Uses the float overload of Random.Range, which is
    // inclusive on both ends; the int overload is max-exclusive, which biased spawns to [-4, 3].
    private Vector3 RandomSpawnPosition()
    {
        return new Vector3(
            Random.Range(-this.spawnExtent, this.spawnExtent),
            PotionHeight,
            Random.Range(-this.spawnExtent, this.spawnExtent));
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

/// <summary>
/// ML-Agents agent that steers on the local X/Z plane toward a target potion while
/// avoiding a poison potion. The potions are supplied by <see cref="Env"/> on each reset.
/// </summary>
public class PlayerAgent : Agent
{
    public float moveSpeed = 5f;

    // NOTE(review): not used for rotation. The agent snaps to face the action direction
    // (LookRotation ignores the vector's length). Kept so existing inspector values still load.
    public float turnSpeed = 180f;

    public Animator anim;
    public Env env;

    public Transform target;
    public Transform poison;

    /// <summary>Distance (local units) at which the agent counts as touching a potion.</summary>
    public float pickupRadius = 1.4f;

    /// <summary>Starts a new episode by resetting the environment (player and potions).</summary>
    public override void OnEpisodeBegin()
    {
        this.env.ResetEnv();
    }

    /// <summary>
    /// Observes target, poison and agent local positions (3 x Vector3 = 9 floats).
    /// Presumably all three share the same parent — confirm in the scene.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.target.localPosition);
        sensor.AddObservation(this.poison.localPosition);
        sensor.AddObservation(this.transform.localPosition);
    }

    /// <summary>
    /// Continuous action [0] is the X direction and [1] is the Z direction. A non-zero direction turns
    /// the agent to face it and moves it forward at <see cref="moveSpeed"/>. Reaching the target
    /// gives +1; touching the poison gives -1. Either one ends the episode.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var continuous = actions.ContinuousActions;
        var controlSignal = new Vector3(continuous[0], 0f, continuous[1]);

        if (controlSignal == Vector3.zero)
        {
            // Idle: keep the current heading. LookRotation(Vector3.zero) would log a warning
            // and snap the agent to identity rotation.
            this.anim.SetInteger("State", 0);
        }
        else
        {
            this.transform.rotation = Quaternion.LookRotation(controlSignal);
            this.anim.SetInteger("State", 1);
            this.transform.Translate(Vector3.forward * this.moveSpeed * Time.deltaTime);
        }

        float distanceToTarget = Vector3.Distance(this.transform.localPosition, this.target.localPosition);
        float distanceToPoison = Vector3.Distance(this.transform.localPosition, this.poison.localPosition);

        if (distanceToTarget < this.pickupRadius)
        {
            SetReward(1.0f);
            EndEpisode();
            // EndEpisode resets the agent immediately (OnEpisodeBegin -> env.ResetEnv).
            // Return so the poison check below cannot charge -1 to the new episode.
            return;
        }

        if (distanceToPoison < this.pickupRadius)
        {
            AddReward(-1.0f);
            EndEpisode();
        }
    }

    /// <summary>Manual control: horizontal/vertical input axes map to X/Z actions.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
    }
}

 

 

To start training with the Potion configuration, run: mlagents-learn ./Potion.yaml --run-id=Potion

+ Recent posts