Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ For more information on Ray see http://ray.readthedocs.io/en/latest/.
First start Ray by executing a command of the following form:

```
ray start --head --redis-port=6379 --num-workers=18
ray start --head
```
This command starts multiple Python processes on one machine for parallel computations with Ray.
Set "num_workers=X" for parallelizing ARS across X CPUs.
Set "--num_cpus=X" for parallelizing ARS across X CPUs.
For parallelzing ARS on a cluster follow the instructions here: http://ray.readthedocs.io/en/latest/using-ray-on-a-large-cluster.html.

We recommend using single threaded linear algebra computations by setting:
Expand Down
6 changes: 3 additions & 3 deletions code/ars.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def do_rollouts(self, w_policy, num_rollouts = 1, shift = 1, evaluate = False):

# for evaluation we do not shift the rewards (shift = 0) and we use the
# default rollout length (1000 for the MuJoCo locomotion tasks)
reward, r_steps = self.rollout(shift = 0., rollout_length = self.env.spec.timestep_limit)
reward, r_steps = self.rollout(shift = 0., rollout_length = self.env.spec.max_episode_steps)
rollout_rewards.append(reward)

else:
Expand Down Expand Up @@ -386,7 +386,7 @@ def run_ars(params):
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--env_name', type=str, default='HalfCheetah-v1')
parser.add_argument('--env_name', type=str, default='HalfCheetah-v2')
parser.add_argument('--n_iter', '-n', type=int, default=1000)
parser.add_argument('--n_directions', '-nd', type=int, default=8)
parser.add_argument('--deltas_used', '-du', type=int, default=8)
Expand All @@ -407,7 +407,7 @@ def run_ars(params):
parser.add_argument('--filter', type=str, default='MeanStdFilter')

local_ip = socket.gethostbyname(socket.gethostname())
ray.init(redis_address= local_ip + ':6379')
ray.init(address= 'auto')

args = parser.parse_args()
params = vars(args)
Expand Down