using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;
/// <summary>
/// Arena controller for the potion task: re-centres the player, spawns one target
/// potion and one poison potion as children of this transform, hands their transforms
/// to the player's <see cref="PlayerAgent"/>, and shows its cumulative reward on screen.
/// </summary>
public class Env : MonoBehaviour
{
    // Spawn coordinates are whole numbers in [-SpawnExtent, SpawnExtent] on local X and Z.
    private const int SpawnExtent = 4;
    // Local Y at which the player and the potions are placed.
    private const float SpawnHeight = 0.1f;
    // Minimum spacing between the two potions, and between each potion and the player,
    // so an episode never starts with the agent already touching something.
    private const float MinSpawnDistance = 2f;
    // Upper bound on rejection-sampling attempts so a cramped arena cannot hang the frame.
    private const int MaxPlacementAttempts = 100;

    public GameObject playerGo;
    public GameObject targetGo;
    public GameObject poisonGo;
    public TMP_Text cumulativeRewardText;

    // Potions spawned by PlacePotions that RemovePotions must destroy. Initialised here
    // rather than in Start so ResetEnv is safe even if it is called before Start runs.
    private readonly List<GameObject> potions = new List<GameObject>();

    // Lazily cached PlayerAgent component of playerGo (GetComponent is not free and
    // Update runs every frame).
    private PlayerAgent playerAgent;

    private void Start()
    {
        this.ResetEnv();
    }

    /// <summary>
    /// Removes the previous potions, moves the player back to the centre and spawns a fresh pair.
    /// </summary>
    public void ResetEnv()
    {
        this.RemovePotions();
        this.PlacePlayer();
        this.PlacePotions();
    }

    private void Update()
    {
        // Text is optional; without it there is nothing to display.
        if (this.cumulativeRewardText == null)
        {
            return;
        }
        this.cumulativeRewardText.text = this.GetPlayerAgent().GetCumulativeReward().ToString("0.00");
    }

    /// <summary>
    /// Instantiates the target and poison prefabs under this transform at random grid
    /// positions. Both potions keep at least <see cref="MinSpawnDistance"/> from the player,
    /// and the poison also keeps that distance from the target. Their transforms are then
    /// assigned to the player's agent.
    /// </summary>
    public void PlacePotions()
    {
        GameObject target = Instantiate(this.targetGo, this.transform);
        GameObject poison = Instantiate(this.poisonGo, this.transform);

        // NOTE(review): compares the player's localPosition with the potions' localPosition,
        // which assumes the player is also a child of this transform - confirm in the scene.
        Vector3 playerPos = this.playerGo.transform.localPosition;

        Vector3 targetPos = RandomSpawnPosition(
            p => Vector3.Distance(p, playerPos) >= MinSpawnDistance);
        target.transform.localPosition = targetPos;

        poison.transform.localPosition = RandomSpawnPosition(
            p => Vector3.Distance(p, targetPos) >= MinSpawnDistance
              && Vector3.Distance(p, playerPos) >= MinSpawnDistance);

        this.potions.Add(target);
        this.potions.Add(poison);

        PlayerAgent agent = this.GetPlayerAgent();
        agent.target = target.transform;
        agent.poison = poison.transform;
    }

    /// <summary>
    /// Destroys every potion spawned by <see cref="PlacePotions"/> and forgets it.
    /// </summary>
    public void RemovePotions()
    {
        foreach (GameObject potion in this.potions)
        {
            if (potion != null)
            {
                Destroy(potion);
            }
        }
        // Without this the list grows every episode and keeps destroyed references.
        this.potions.Clear();
    }

    /// <summary>Misspelled legacy name; forwards to <see cref="RemovePotions"/>.</summary>
    [System.Obsolete("Use RemovePotions instead.")]
    public void RemovePosions()
    {
        this.RemovePotions();
    }

    /// <summary>Moves the player to the arena centre at local (0, 0.1, 0).</summary>
    public void PlacePlayer()
    {
        this.playerGo.transform.localPosition = new Vector3(0, SpawnHeight, 0);
    }

    // Returns the PlayerAgent on playerGo, fetching it only when not yet cached.
    private PlayerAgent GetPlayerAgent()
    {
        if (this.playerAgent == null)
        {
            this.playerAgent = this.playerGo.GetComponent<PlayerAgent>();
        }
        return this.playerAgent;
    }

    // Draws random grid positions until one satisfies isValid, giving up after
    // MaxPlacementAttempts and returning the last candidate (best effort).
    private static Vector3 RandomSpawnPosition(System.Func<Vector3, bool> isValid)
    {
        Vector3 candidate = RandomGridPosition();
        int attempts = 1;
        while (!isValid(candidate))
        {
            if (attempts >= MaxPlacementAttempts)
            {
                Debug.LogWarning("Env: could not find a spawn position satisfying the spacing rules.");
                break;
            }
            candidate = RandomGridPosition();
            attempts++;
        }
        return candidate;
    }

    // The int overload of Random.Range excludes its max, so +1 makes the grid symmetric
    // (-SpawnExtent..SpawnExtent inclusive).
    private static Vector3 RandomGridPosition()
    {
        return new Vector3(
            Random.Range(-SpawnExtent, SpawnExtent + 1),
            SpawnHeight,
            Random.Range(-SpawnExtent, SpawnExtent + 1));
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// ML-Agents agent that steers toward the target potion and away from the poison.
/// Observes target, poison and own local positions. Receives two continuous actions
/// (X and Z move direction). Earns +1 for touching the target and -1 for touching the poison;
/// either ends the episode.
/// </summary>
public class PlayerAgent : Agent
{
    // Distance at which the agent counts as touching the target or the poison.
    private const float TouchDistance = 1.4f;
    // Values written to the Animator's "State" integer parameter.
    private const int IdleState = 0;
    private const int MoveState = 1;

    public float moveSpeed = 5f;
    // NOTE: Quaternion.LookRotation uses only the direction of its argument, so this value
    // does not currently affect turning. Kept so serialized inspector values survive.
    public float turnSpeed = 180f;
    public Animator anim;
    public Env env;
    public Transform target;
    public Transform poison;

    /// <summary>Resets the arena (player and potions) at the start of every episode.</summary>
    public override void OnEpisodeBegin()
    {
        this.env.ResetEnv();
    }

    /// <summary>Adds 9 floats: target, poison and agent local positions.</summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.target.localPosition);
        sensor.AddObservation(this.poison.localPosition);
        sensor.AddObservation(this.transform.localPosition);
    }

    /// <summary>
    /// Interprets actions[0] / actions[1] as an X/Z move direction. The agent faces that
    /// direction and moves forward at moveSpeed, then the touch checks assign the reward.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.ContinuousActions;
        Vector3 controlSignal = new Vector3(action[0], 0f, action[1]);

        if (controlSignal == Vector3.zero)
        {
            // No input: keep the current facing. LookRotation(zero) would log a warning
            // and snap the agent back to the identity rotation.
            this.anim.SetInteger("State", IdleState);
        }
        else
        {
            this.transform.rotation = Quaternion.LookRotation(controlSignal);
            this.anim.SetInteger("State", MoveState);
            this.transform.Translate(Vector3.forward * this.moveSpeed * Time.deltaTime);
        }

        float distanceToTarget = Vector3.Distance(this.transform.localPosition, this.target.localPosition);
        float distanceToPoison = Vector3.Distance(this.transform.localPosition, this.poison.localPosition);

        // else-if: EndEpisode() resets the arena via OnEpisodeBegin, so a second check would
        // use stale distances and end the brand-new episode with a penalty.
        if (distanceToTarget < TouchDistance)
        {
            SetReward(1.0f);
            EndEpisode();
        }
        else if (distanceToPoison < TouchDistance)
        {
            SetReward(-1.0f);
            EndEpisode();
        }
    }

    /// <summary>Manual control: the Horizontal/Vertical input axes map to X/Z movement.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
    }
}
# Start ML-Agents training with the trainer configuration ./Potion.yaml, using run ID "Potion".
mlagents-learn ./Potion.yaml --run-id=Potion
'Unity3D' 카테고리의 다른 글
Unity) [유니티 3D] 점프 애니메이션 제어 (0) | 2023.07.04 |
---|---|
Unity) [유니티 3D] TPS 카메라, 플레이어 조작 구현 (0) | 2023.07.03 |
Unity) [유니티 3D] CharacterController를 통한 이동, 중력 적용 (0) | 2023.06.27 |
Unity) [유니티 3D] 오브젝트 생성, 이동 범위 지정하기 (0) | 2023.03.28 |
Unity) [ml-agents] RollerBall AI 학습시키기 (0) | 2023.03.27 |