-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathenv_gym.py
More file actions
121 lines (98 loc) · 4.47 KB
/
env_gym.py
File metadata and controls
121 lines (98 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
'''
Gymnasium class for the CARLAEnv. Establishes communication with the env_agent and serves as gymnasium interface.
'''
import os
import math
import pathlib
import gymnasium as gym
from gymnasium import spaces
import zmq
import numpy as np
class CARLAEnv(gym.Env):
    '''
    Gymnasium environment class interface. Handles communication with env_agent.py

    The environment owns one ZeroMQ PAIR socket. On the first reset() the socket is bound to
    ipc://<folder of this file>/comm_files/<port>.lock and every observation is then read from it
    as a multipart message with the following frames:
        0: bev_semantics      (uint8, reshaped to channels x height x width from the config)
        1: measurements       (float32)
        2: value_measurements (float32)
        3: reward             (float32 scalar, only read in step)
        4: termination        (bool scalar, only read in step)
        5: truncation         (bool scalar, only read in step)
        6: n_steps            (int32)
        7: suggest            (int32)
        8: number of messages the peer has sent so far (uint64 scalar)
    '''
    metadata = {'render_modes': ['rgb_array']}

    def __init__(self, port, config, render_mode='rgb_array'):  # pylint: disable=locally-disabled, unused-argument
        '''
        Defines the observation / action spaces from the config and creates the (still unbound) socket.

        Args:
            port: Identifier used as file name of the ipc endpoint that reset() binds to.
            config: Object providing obs_num_channels, bev_semantics_height, bev_semantics_width,
                obs_num_measurements, num_value_measurements, action_space_min, action_space_max,
                action_space_dim and frame_rate.
            render_mode: Accepted for interface compatibility; not used by this class.
        '''
        # Number of multipart messages received so far; compared with the sender's own counter.
        self.num_recv = 0

        bev_shape = (config.obs_num_channels, config.bev_semantics_height, config.bev_semantics_width)
        self.observation_space = spaces.Dict({
            'bev_semantics': spaces.Box(0, 255, shape=bev_shape, dtype=np.uint8),
            'measurements': spaces.Box(-math.inf, math.inf, shape=(config.obs_num_measurements,), dtype=np.float32),
            'value_measurements': spaces.Box(-math.inf,
                                             math.inf,
                                             shape=(config.num_value_measurements,),
                                             dtype=np.float32),
        })
        self.action_space = spaces.Box(config.action_space_min,
                                       config.action_space_max,
                                       shape=(config.action_space_dim,),
                                       dtype=np.float32)

        # Build a per-instance copy instead of writing into the class-level dict, which is shared
        # by all instances of this class.
        self.metadata = {**type(self).metadata, 'render_fps': config.frame_rate}

        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PAIR)
        self.port = port
        self.initialized = False  # Becomes True once reset() has bound the socket.
        self.config = config

    def _decode_observation(self, data):
        '''
        Builds the observation dict from frames 0-2 of a received multipart message.
        '''
        bev_shape = (self.config.obs_num_channels, self.config.bev_semantics_height, self.config.bev_semantics_width)
        return {
            'bev_semantics': np.frombuffer(data[0], dtype=np.uint8).reshape(*bev_shape),
            'measurements': np.frombuffer(data[1], dtype=np.float32),
            'value_measurements': np.frombuffer(data[2], dtype=np.float32),
        }

    def _check_frame_counter(self, data):
        '''
        Raises ValueError if the sender's message counter (frame 8) differs from self.num_recv.
        '''
        num_sent = np.frombuffer(data[8], dtype=np.uint64).item()
        if self.num_recv != num_sent:
            raise ValueError(f"Communication breakdown, Leaderboard send more frames than client consumed. "
                             f"num_recv: {self.num_recv}, num_sent: {num_sent}")

    def reset(self, seed=None, options=None):  # pylint: disable=locally-disabled, unused-argument
        '''
        Waits for the next message on the socket and returns it as the initial observation.

        On the first call the socket is bound to its ipc endpoint and the method blocks until the
        peer has sent a greeting string, which is printed.

        Args:
            seed: Forwarded to gym.Env.reset to seed self.np_random.
            options: Unused.

        Returns:
            Tuple (observation, info). info holds 'n_steps' and 'suggest' as int32 numpy arrays.

        Raises:
            ValueError: If the number of received messages does not match the sender's counter.
        '''
        # We need the following line to seed self.np_random
        super().reset(seed=seed)

        if not self.initialized:
            # Connect to env_agent.
            comm_folder = pathlib.Path(__file__).parent.resolve() / 'comm_files'
            comm_folder.mkdir(parents=True, exist_ok=True)
            communication_file = comm_folder / str(self.port)
            self.socket.bind(f'ipc://{communication_file}.lock')
            print(f'Connecting to leaderboard gym, port: {self.port}')

            msg = self.socket.recv_string()
            print(msg)
            self.initialized = True

        data = self.socket.recv_multipart(copy=False)
        self.num_recv += 1

        observation = self._decode_observation(data)
        # NOTE(review): unlike step(), these values are returned as arrays rather than Python ints.
        # Kept as is so that existing callers see the same types - confirm whether this is intended.
        info = {
            'n_steps': np.frombuffer(data[6], dtype=np.int32),
            'suggest': np.frombuffer(data[7], dtype=np.int32),
        }
        self._check_frame_counter(data)

        return observation, info

    def step(self, action):
        '''
        Sends the action to the peer and blocks until the resulting message arrives.

        Args:
            action: Numpy array; its raw bytes (action.tobytes()) are sent over the socket.

        Returns:
            Tuple (observation, reward, termination, truncation, info) where info holds 'n_steps'
            and 'suggest' as Python ints.

        Raises:
            ValueError: If the number of received messages does not match the sender's counter.
        '''
        self.socket.send(action.tobytes(), copy=False)

        data = self.socket.recv_multipart(copy=False)
        self.num_recv += 1

        observation = self._decode_observation(data)
        reward = np.frombuffer(data[3], dtype=np.float32).item()
        termination = np.frombuffer(data[4], dtype=bool).item()  # True if agent ended in destroy method.
        truncation = np.frombuffer(data[5], dtype=bool).item()  # True if agent timed out.
        info = {
            'n_steps': np.frombuffer(data[6], dtype=np.int32).item(),
            'suggest': np.frombuffer(data[7], dtype=np.int32).item(),
        }
        self._check_frame_counter(data)

        return observation, reward, termination, truncation, info

    def close(self):
        '''
        Releases the ZeroMQ socket and context. Safe to call more than once.
        '''
        print('Called close!')
        if not self.socket.closed:
            # linger=0 discards unsent messages so that terminating the context cannot block on them.
            self.socket.close(linger=0)
        if not self.context.closed:
            self.context.term()