Skip to content

Commit

Permalink
cleanup initial commit (Denys88#258)
Browse files Browse the repository at this point in the history
* cleanup initial commit

* added smac_v2 support

* added all 3 initial

---------

Co-authored-by: Denys Makoviichuk <[email protected]>
  • Loading branch information
Denys88 and DenSumy authored Oct 20, 2023
1 parent 33ba628 commit 32f70ee
Show file tree
Hide file tree
Showing 72 changed files with 830 additions and 263 deletions.
33 changes: 19 additions & 14 deletions rl_games/common/env_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,6 @@ def create_slime_gym_env(**kwargs):
env = gym.make(name, **kwargs)
return env

def create_connect_four_env(**kwargs):
from rl_games.envs.connect4_selfplay import ConnectFourSelfPlay
name = kwargs.pop('name')
limit_steps = kwargs.pop('limit_steps', False)
self_play = kwargs.pop('self_play', False)
if self_play:
env = ConnectFourSelfPlay(name, **kwargs)
else:
env = gym.make(name, **kwargs)
return env

def create_atari_gym_env(**kwargs):
#frames = kwargs.pop('frames', 1)
Expand Down Expand Up @@ -171,6 +161,21 @@ def create_smac(name, **kwargs):
env = SMACEnv(name, **kwargs)


if frames > 1:
if has_cv:
env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten)
else:
env = wrappers.BatchedFrameStack(env, frames, transpose=False, flatten=flatten)
return env

def create_smac_v2(name, **kwargs):
from rl_games.envs.smac_v2_env import SMACEnvV2
frames = kwargs.pop('frames', 1)
transpose = kwargs.pop('transpose', False)
flatten = kwargs.pop('flatten', True)
has_cv = kwargs.get('central_value', False)
env = SMACEnvV2(name, **kwargs)

if frames > 1:
if has_cv:
env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten)
Expand Down Expand Up @@ -359,6 +364,10 @@ def create_env(name, **kwargs):
'env_creator' : lambda **kwargs : create_smac(**kwargs),
'vecenv_type' : 'RAY'
},
'smac_v2' : {
'env_creator' : lambda **kwargs : create_smac_v2(**kwargs),
'vecenv_type' : 'RAY'
},
'smac_cnn' : {
'env_creator' : lambda **kwargs : create_smac_cnn(**kwargs),
'vecenv_type' : 'RAY'
Expand Down Expand Up @@ -391,10 +400,6 @@ def create_env(name, **kwargs):
'env_creator' : lambda **kwargs : create_minigrid_env(kwargs.pop('name'), **kwargs),
'vecenv_type' : 'RAY'
},
'connect4_env' : {
'env_creator' : lambda **kwargs : create_connect_four_env(**kwargs),
'vecenv_type' : 'RAY'
},
'multiwalker_env' : {
'env_creator' : lambda **kwargs : create_multiwalker_env(**kwargs),
'vecenv_type' : 'RAY'
Expand Down
4 changes: 2 additions & 2 deletions rl_games/common/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, size, ob_space):
self._next_obses = np.zeros((size,) + ob_space.shape, dtype=ob_space.dtype)
self._rewards = np.zeros(size)
self._actions = np.zeros(size, dtype=np.int32)
self._dones = np.zeros(size, dtype=np.bool)
self._dones = np.zeros(size, dtype=bool)

self._maxsize = size
self._next_idx = 0
Expand Down Expand Up @@ -341,7 +341,7 @@ def _init_from_env_info(self, env_info):
if self.is_discrete or self.is_multi_discrete:
self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=int), obs_base_shape)
if self.use_action_masks:
self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=np.bool), obs_base_shape)
self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=bool), obs_base_shape)
if self.is_continuous:
self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
self.tensor_dict['mus'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
Expand Down
25 changes: 13 additions & 12 deletions rl_games/common/vecenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from time import sleep
import torch


class RayWorker:
def __init__(self, config_name, config):
self.env = configurations[config_name]['env_creator'](**config)
Expand Down Expand Up @@ -96,30 +95,32 @@ def get_env_info(self):


class RayVecEnv(IVecEnv):
import ray

def __init__(self, config_name, num_actors, **kwargs):
self.config_name = config_name
self.num_actors = num_actors
self.use_torch = False
self.seed = kwargs.pop('seed', None)

import ray
self.remote_worker = ray.remote(RayWorker)

self.remote_worker = self.ray.remote(RayWorker)
self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)]

if self.seed is not None:
seeds = range(self.seed, self.seed + self.num_actors)
seed_set = []
for (seed, worker) in zip(seeds, self.workers):
seed_set.append(worker.seed.remote(seed))
ray.get(seed_set)
self.ray.get(seed_set)

res = self.workers[0].get_number_of_agents.remote()
self.num_agents = ray.get(res)
self.num_agents = self.ray.get(res)

res = self.workers[0].get_env_info.remote()
env_info = ray.get(res)
env_info = self.ray.get(res)
res = self.workers[0].can_concat_infos.remote()
can_concat_infos = ray.get(res)
can_concat_infos = self.ray.get(res)
self.use_global_obs = env_info['use_global_observations']
self.concat_infos = can_concat_infos
self.obs_type_dict = type(env_info.get('observation_space')) is gym.spaces.Dict
Expand All @@ -139,7 +140,7 @@ def step(self, actions):
for num, worker in enumerate(self.workers):
res_obs.append(worker.step.remote(actions[self.num_agents * num: self.num_agents * num + self.num_agents]))

all_res = ray.get(res_obs)
all_res = self.ray.get(res_obs)
for res in all_res:
cobs, crewards, cdones, cinfos = res
if self.use_global_obs:
Expand Down Expand Up @@ -171,27 +172,27 @@ def step(self, actions):

def get_env_info(self):
res = self.workers[0].get_env_info.remote()
return ray.get(res)
return self.ray.get(res)

def set_weights(self, indices, weights):
res = []
for ind in indices:
res.append(self.workers[ind].set_weights.remote(weights))
ray.get(res)
self.ray.get(res)

def has_action_masks(self):
return True

def get_action_masks(self):
mask = [worker.get_action_mask.remote() for worker in self.workers]
masks = ray.get(mask)
masks = self.ray.get(mask)
return np.concatenate(masks, axis=0)

def reset(self):
res_obs = [worker.reset.remote() for worker in self.workers]
newobs, newstates = [],[]
for res in res_obs:
cobs = ray.get(res)
cobs = self.ray.get(res)
if self.use_global_obs:
newobs.append(cobs["obs"])
newstates.append(cobs["state"])
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
69 changes: 69 additions & 0 deletions rl_games/configs/smac/v2/env_configs/sc2_gen_protoss.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
env: sc2wrapped

env_args:
continuing_episode: False
difficulty: "7"
game_version: null
map_name: "10gen_protoss"
move_amount: 2
obs_all_health: True
obs_instead_of_state: False
obs_last_action: False
obs_own_health: True
obs_pathing_grid: False
obs_terrain_height: False
obs_timestep_number: False
reward_death_value: 10
reward_defeat: 0
reward_negative_scale: 0.5
reward_only_positive: True
reward_scale: True
reward_scale_rate: 20
reward_sparse: False
reward_win: 200
replay_dir: ""
replay_prefix: ""
conic_fov: False
use_unit_ranges: True
min_attack_range: 2
obs_own_pos: True
num_fov_actions: 12
capability_config:
n_units: 5
n_enemies: 5
team_gen:
dist_type: "weighted_teams"
unit_types:
- "stalker"
- "zealot"
- "colossus"
weights:
- 0.45
- 0.45
- 0.1
observe: True
start_positions:
dist_type: "surrounded_and_reflect"
p: 0.5
map_x: 32
map_y: 32

# enemy_mask:
# dist_type: "mask"
# mask_probability: 0.5
# n_enemies: 5
state_last_action: True
state_timestep_number: False
step_mul: 8
heuristic_ai: False
# heuristic_rest: False
debug: False
prob_obs_enemy: 1.0
action_mask: True

test_nepisode: 32
test_interval: 10000
log_interval: 2000
runner_log_interval: 2000
learner_log_interval: 2000
t_max: 10050000
70 changes: 70 additions & 0 deletions rl_games/configs/smac/v2/env_configs/sc2_gen_protoss_epo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
env: sc2wrapped

env_args:
continuing_episode: False
difficulty: "7"
game_version: null
map_name: "10gen_protoss"
move_amount: 2
obs_all_health: True
obs_instead_of_state: False
obs_last_action: False
obs_own_health: True
obs_pathing_grid: False
obs_terrain_height: False
obs_timestep_number: False
reward_death_value: 10
reward_defeat: 0
reward_negative_scale: 0.5
reward_only_positive: True
reward_scale: True
reward_scale_rate: 20
reward_sparse: False
reward_win: 200
replay_dir: ""
replay_prefix: ""
conic_fov: False
use_unit_ranges: True
min_attack_range: 2
obs_own_pos: True
num_fov_actions: 12
capability_config:
n_units: 5
n_enemies: 5
team_gen:
dist_type: "weighted_teams"
unit_types:
- "stalker"
- "zealot"
- "colossus"
weights:
- 0.45
- 0.45
- 0.1
observe: True
start_positions:
dist_type: "surrounded_and_reflect"
p: 0.5
map_x: 32
map_y: 32

# enemy_mask:
# dist_type: "mask"
# mask_probability: 0.5
# n_enemies: 5
state_last_action: True
state_timestep_number: False
step_mul: 8
heuristic_ai: False
# heuristic_rest: False
debug: False
# Most severe partial obs setting:
prob_obs_enemy: 0.0
action_mask: False

test_nepisode: 32
test_interval: 10000
log_interval: 2000
runner_log_interval: 2000
learner_log_interval: 2000
t_max: 10050000
71 changes: 71 additions & 0 deletions rl_games/configs/smac/v2/env_configs/sc2_gen_terran.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
env: sc2wrapped

env_args:
continuing_episode: False
difficulty: "7"
game_version: null
map_name: "10gen_terran"
move_amount: 2
obs_all_health: True
obs_instead_of_state: False
obs_last_action: False
obs_own_health: True
obs_pathing_grid: False
obs_terrain_height: False
obs_timestep_number: False
reward_death_value: 10
reward_defeat: 0
reward_negative_scale: 0.5
reward_only_positive: True
reward_scale: True
reward_scale_rate: 20
reward_sparse: False
reward_win: 200
replay_dir: ""
replay_prefix: ""
conic_fov: False
obs_own_pos: True
use_unit_ranges: True
min_attack_range: 2
num_fov_actions: 12
capability_config:
n_units: 5
n_enemies: 5
team_gen:
dist_type: "weighted_teams"
unit_types:
- "marine"
- "marauder"
- "medivac"
weights:
- 0.45
- 0.45
- 0.1
exception_unit_types:
- "medivac"
observe: True

start_positions:
dist_type: "surrounded_and_reflect"
p: 0.5
map_x: 32
map_y: 32
# enemy_mask:
# dist_type: "mask"
# mask_probability: 0.5
# n_enemies: 5
state_last_action: True
state_timestep_number: False
step_mul: 8
heuristic_ai: False
# heuristic_rest: False
debug: False
prob_obs_enemy: 1.0
action_mask: True

test_nepisode: 32
test_interval: 10000
log_interval: 2000
runner_log_interval: 2000
learner_log_interval: 2000
t_max: 10050000
Loading

0 comments on commit 32f70ee

Please sign in to comment.