-
Notifications
You must be signed in to change notification settings - Fork 0
/
blue_move.cs
92 lines (80 loc) · 2.86 KB
/
blue_move.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using System.Collections;
using UnityEngine.UI;
/// <summary>
/// ML-Agents agent for the blue side.
/// While its GameObject is tagged "movable", each discrete action either nudges it
/// along its local forward/up axes or switches its tag to "on_floor" / "attack".
/// Every frame it is turned to face <see cref="enemy"/>.
/// Being touched by an "attack"-tagged trigger gives a -1 reward and ends the episode.
/// </summary>
public class blue_move : Agent
{
    // Tags read and written by this agent. They must exist in the project's Tag Manager.
    private const string MovableTag = "movable";
    private const string OnFloorTag = "on_floor";
    private const string AttackTag = "attack";
    private const string RedWinsTag = "win_red";

    // Values of discrete action branch 0 (see OnActionReceived / Heuristic).
    private const int ActionNone = 0;
    private const int ActionForward = 1;
    private const int ActionBackward = 2;
    private const int ActionUp = 3;
    private const int ActionDown = 4;
    private const int ActionOnFloor = 5;
    private const int ActionAttack = 6;

    Rigidbody rBody;
    public GameObject enemy;
    public GameObject score;

    /// <summary>
    /// Magnitude of the velocity change (ForceMode.VelocityChange) applied by a movement action.
    /// Defaults to the previously hard-coded value of 10.
    /// </summary>
    [Tooltip("Velocity change applied per movement action.")]
    public float moveSpeed = 10f;

    /// <summary>Called once by ML-Agents when the agent is initialized; caches the Rigidbody.</summary>
    public override void Initialize()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Adds this agent's and the enemy's local position and local rotation to the observation vector.
    /// Vector3 adds 3 floats and Quaternion adds 4, so 14 floats total.
    /// NOTE(review): the Behavior Parameters vector observation size must match — confirm.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition);
        sensor.AddObservation(transform.localRotation);
        sensor.AddObservation(enemy.transform.localPosition);
        sensor.AddObservation(enemy.transform.localRotation);
    }

    /// <summary>
    /// Applies discrete action branch 0, but only while this GameObject is tagged "movable":
    /// 1/2 move along local +/- forward, 3/4 move along local +/- up,
    /// 5 sets the tag to "on_floor", 6 sets the tag to "attack"; any other value does nothing.
    /// </summary>
    /// <param name="actions">Action buffers supplied by ML-Agents.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        int action = actions.DiscreteActions[0];
        if (!CompareTag(MovableTag))
        {
            // Agents in any other state (e.g. already "on_floor" or "attack") ignore actions.
            return;
        }

        Vector3 dirToGo = Vector3.zero;
        switch (action)
        {
            case ActionForward:
                dirToGo = transform.forward;
                break;
            case ActionBackward:
                dirToGo = -transform.forward;
                break;
            case ActionUp:
                dirToGo = transform.up;
                break;
            case ActionDown:
                dirToGo = -transform.up;
                break;
        }
        // Called unconditionally (with a zero vector for non-movement actions) to match prior behavior.
        rBody.AddForce(dirToGo * moveSpeed, ForceMode.VelocityChange);

        if (action == ActionOnFloor)
        {
            tag = OnFloorTag;
        }
        else if (action == ActionAttack)
        {
            tag = AttackTag;
        }
    }

    /// <summary>
    /// Every frame: points local +Z at the enemy, applies a fixed extra rotation, and zeroes the
    /// Rigidbody velocity.
    /// </summary>
    void Update()
    {
        // LookAt uses world +Z as the "up" hint. The Euler rotation (-180, -180, 90) composes
        // to a quarter-turn roll about local Z, so +Z keeps pointing at the enemy while local up
        // (used by actions 3/4) is turned sideways.
        transform.LookAt(enemy.transform, Vector3.forward);
        transform.Rotate(new Vector3(-180f, -180f, +90f));
        // NOTE(review): this runs once per rendered frame, not per physics step. When several
        // FixedUpdates occur between frames (e.g. a high time scale during training), the velocity
        // from a movement action persists across all of them, so distance moved depends on frame rate.
        rBody.velocity = Vector3.zero;
    }

    /// <summary>
    /// Unity trigger callback. If the other collider is tagged "attack", tags the score object
    /// "win_red", gives this agent a -1 reward, and ends the episode.
    /// </summary>
    /// <param name="other">The collider that entered the trigger.</param>
    void OnTriggerEnter(Collider other)
    {
        if (!other.CompareTag(AttackTag))
        {
            return;
        }

        score.tag = RedWinsTag;
        AddReward(-1f);
        EndEpisode();
    }

    /// <summary>
    /// Keyboard control for Heuristic mode.
    /// Up/Down arrows = forward/backward, Left/Right arrows = local up/down,
    /// Space = "on_floor", S = "attack", nothing held = no-op.
    /// Later checks override earlier ones, so with several keys held the priority is
    /// S > Space > Right > Left > Down > Up.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discrete = actionsOut.DiscreteActions;
        discrete[0] = ActionNone;
        if (Input.GetKey(KeyCode.UpArrow)) discrete[0] = ActionForward;
        if (Input.GetKey(KeyCode.DownArrow)) discrete[0] = ActionBackward;
        if (Input.GetKey(KeyCode.LeftArrow)) discrete[0] = ActionUp;
        if (Input.GetKey(KeyCode.RightArrow)) discrete[0] = ActionDown;
        if (Input.GetKey(KeyCode.Space)) discrete[0] = ActionOnFloor;
        if (Input.GetKey(KeyCode.S)) discrete[0] = ActionAttack;
    }
}
// Training command: mlagents-learn config/exvs.yaml --run-id=exvs --env=apps/exvs --height=900 --width=1600 --force