-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
Copy pathtest_rpc_communicator.py
96 lines (80 loc) · 3.27 KB
/
test_rpc_communicator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from unittest.mock import Mock
import pytest
from unittest import mock
import grpc
import mlagents_envs.rpc_communicator
from mlagents_envs.rpc_communicator import RpcCommunicator
from mlagents_envs.exception import (
UnityWorkerInUseException,
UnityTimeOutException,
UnityEnvironmentException,
)
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
@pytest.mark.parametrize("n_ports", [1])
def test_rpc_communicator_checks_port_on_create(base_port: int) -> None:
    """Creating a second communicator on an already-used port must fail.

    ``base_port`` is provided by a fixture defined outside this file; the
    ``n_ports`` parametrization presumably controls how many ports it
    reserves (verify in conftest).
    """
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        with pytest.raises(UnityWorkerInUseException):
            second_comm = RpcCommunicator(base_port=base_port)
            # Only reached if the constructor unexpectedly succeeded: close
            # it before pytest.raises reports the failure.
            second_comm.close()
    finally:
        # Close first_comm even when the assertion above fails, so the port
        # it holds is not leaked into subsequent tests.
        first_comm.close()
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_close(base_port: int) -> None:
    """A new RPC communicator can be opened after a previous one is closed.

    Both communicators use the default ``worker_id``. The second one binds
    ``base_port + 1``; ``n_ports=2`` presumably asks the ``base_port``
    fixture to reserve both ports (verify in conftest).
    """
    RpcCommunicator(base_port=base_port).close()
    reopened = RpcCommunicator(base_port=base_port + 1)
    reopened.close()
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_create_multiple_workers(base_port: int) -> None:
    """Communicators sharing ``base_port`` but with distinct worker_ids coexist.

    The second communicator passes ``worker_id=1``. Creating it must not
    raise, which presumably means the worker_id offsets the bound port.
    ``n_ports=2`` asks the ``base_port`` fixture for room for both (verify
    in conftest).
    """
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        second_comm = RpcCommunicator(base_port=base_port, worker_id=1)
    finally:
        # Close the first communicator even if creating the second one
        # raises, so its port is not left bound for later tests.
        first_comm.close()
    second_comm.close()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """``initialize`` completes when the stubbed ``parent_conn.poll()`` is True.

    ``grpc.server`` and ``UnityToExternalServicerImplementation`` are patched
    for the duration of the test. ``mock.patch`` decorators inject their
    mocks bottom-up, which is why ``mock_impl`` precedes ``mock_grpc_server``.
    """
    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    # Presumably comm.unity_to_external is an instance of the patched
    # servicer class, so parent_conn.poll is a Mock that can be stubbed.
    comm.unity_to_external.parent_conn.poll.return_value = True
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """``initialize`` raises UnityTimeOutException when poll never succeeds.

    ``parent_conn.poll()`` is stubbed to return None (falsy). The small
    ``timeout_wait`` (presumably seconds) keeps the wait for the timeout
    short.
    """
    comm = RpcCommunicator(timeout_wait=0.25, base_port=base_port)
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    # Expect a timeout
    with pytest.raises(UnityTimeOutException):
        comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """An exception raised by ``poll_callback`` propagates out of ``initialize``.

    ``parent_conn.poll()`` is stubbed to return None, so no response is ever
    reported. The test expects the UnityEnvironmentException raised by the
    callback to escape ``initialize``.
    """

    def callback():
        raise UnityEnvironmentException

    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    # Expect the exception raised by ``callback``, not a timeout.
    with pytest.raises(UnityEnvironmentException):
        comm.initialize(unity_input, poll_callback=callback)
    comm.unity_to_external.parent_conn.poll.assert_called()