Skip to content

Commit

Permalink
fix a bug in acer saving and loading model (openai#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanHaiyangChen authored and pzhokhov committed Sep 27, 2019
1 parent 5379729 commit f703776
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions baselines/acer/acer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from baselines.common import set_global_seeds
from baselines.common.policies import build_policy
-from baselines.common.tf_util import get_session, save_variables
+from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.common.vec_env.vec_frame_stack import VecFrameStack

from baselines.a2c.utils import batch_to_seq, seq_to_batch
Expand Down Expand Up @@ -216,7 +216,8 @@ def _step(observation, **kwargs):


self.train = train
-self.save = functools.partial(save_variables, sess=sess, variables=params)
+self.save = functools.partial(save_variables, sess=sess)
+self.load = functools.partial(load_variables, sess=sess)
self.train_model = train_model
self.step_model = step_model
self._step = _step
Expand Down Expand Up @@ -358,6 +359,9 @@ def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=
total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
trust_region=trust_region, alpha=alpha, delta=delta)

+if load_path is not None:
+    model.load(load_path)
+
runner = Runner(env=env, model=model, nsteps=nsteps)
if replay_ratio > 0:
buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
Expand Down

0 comments on commit f703776

Please sign in to comment.