diff --git a/dopamine/agents/dqn/dqn_agent.py b/dopamine/agents/dqn/dqn_agent.py index 2dafafb1..ee2f28ac 100644 --- a/dopamine/agents/dqn/dqn_agent.py +++ b/dopamine/agents/dqn/dqn_agent.py @@ -101,7 +101,12 @@ def __init__(self, centered=True), summary_writer=None, summary_writing_frequency=500, - allow_partial_reload=False): + allow_partial_reload=False, + reset_period=None, + reset_dense1=False, + reset_dense2=False, + reset_last_layer=False, + reset_max=3): """Initializes the agent and constructs the components of its graph. Args: @@ -182,6 +187,16 @@ def __init__(self, self.eval_mode = eval_mode self.training_steps = 0 self.optimizer = optimizer + # Modified + self.optimizer_state = self.optimizer.variables() + + self.reset_period = reset_period + self.reset_dense1 = reset_dense1 + self.reset_dense2 = reset_dense2 + self.reset_last_layer = reset_last_layer + self.reset_max = reset_max + self.reset_counter = 0 + tf.compat.v1.disable_v2_behavior() if isinstance(summary_writer, str): # If we're passing in directory name. self.summary_writer = tf.compat.v1.summary.FileWriter(summary_writer) @@ -210,6 +225,8 @@ def __init__(self, self._build_networks() + self.online_convnet_state = self.online_convnet.get_weights() + self._train_op = self._build_train_op() self._sync_qt_ops = self._build_sync_op() @@ -231,6 +248,7 @@ def __init__(self, self.summary_writer.add_graph(graph=tf.compat.v1.get_default_graph()) self._sess.run(tf.compat.v1.global_variables_initializer()) + def _create_network(self, name): """Builds the convolutional network used to compute the agent's Q-values. @@ -450,8 +468,48 @@ def _train_step(self): if self.training_steps % self.target_update_period == 0: self._sess.run(self._sync_qt_ops) + if (self.reset_period is not None and + self.training_steps % self.reset_period == 0\ + and self.reset_counter < self.reset_max): + print("Resetting last layers...") + self.ResetWeights() + self.training_steps += 1 + def ResetWeights(self): + # Reset the weights of the last layer + # self.online_convnet.set_weights(self.online_convnet_state) + # self.target_convnet.set_weights(self.online_convnet_state) + if self.reset_counter >= self.reset_max: + return + + print("Resetting weights...") + if self.reset_last_layer: + print("Resetting last layer!") + self.online_convnet.layers[-1].last_layer.kernel.initializer.run(session=self._sess) + self.online_convnet.layers[-1].last_layer.bias.initializer.run(session=self._sess) + + if self.reset_dense1: + print("Resetting dense1 layer!") + self.online_convnet.layers[-1].dense1.kernel.initializer.run(session=self._sess) + self.online_convnet.layers[-1].dense1.bias.initializer.run(session=self._sess) + + if self.reset_dense2: + print("Resetting dense2 layer!") + self.online_convnet.layers[-1].dense2.kernel.initializer.run(session=self._sess) + self.online_convnet.layers[-1].dense2.bias.initializer.run(session=self._sess) + + # Legacy code + # self.online_convnet.last_layer.kernel.initializer.run(session=self._sess) + # self.online_convnet.last_layer.bias.initializer.run(session=self._sess) + + # self._sess.run(tf.compat.v1.global_variables_initializer()) + # Reset the optimizer state + optimizer_reset = tf.compat.v1.variables_initializer(self.optimizer_state) + self._sess.run(optimizer_reset) + + self.reset_counter += 1 + def _record_observation(self, observation): """Records an observation and update state. diff --git a/dopamine/discrete_domains/atari_lib.py b/dopamine/discrete_domains/atari_lib.py index 91a5a4ce..ddd3c8d6 100644 --- a/dopamine/discrete_domains/atari_lib.py +++ b/dopamine/discrete_domains/atari_lib.py @@ -158,6 +158,11 @@ def __init__(self, num_actions, name=None): name='fully_connected') self.dense2 = tf.keras.layers.Dense(num_actions, name='fully_connected') + # Modification + def reset_last_layer(self): + """Reset the last layer of the network.""" + self.dense2 = tf.keras.layers.Dense(self.num_actions, name='fully_connected') + def call(self, state): """Creates the output tensor/op given the state tensor as input. diff --git a/dopamine/discrete_domains/gym_lib.py b/dopamine/discrete_domains/gym_lib.py index 4ff8b9c6..b61716d0 100644 --- a/dopamine/discrete_domains/gym_lib.py +++ b/dopamine/discrete_domains/gym_lib.py @@ -114,6 +114,7 @@ def __init__(self, min_vals, max_vals, num_actions, self.num_atoms = num_atoms self.min_vals = min_vals self.max_vals = max_vals + self.activation_fn = activation_fn # Defining layers. self.flatten = tf.keras.layers.Flatten() self.dense1 = tf.keras.layers.Dense(512, activation=activation_fn, @@ -127,6 +128,22 @@ def __init__(self, min_vals, max_vals, num_actions, self.last_layer = tf.keras.layers.Dense(num_actions * num_atoms, name='fully_connected') + # Modified: saving the initial weights to load them after + # model.save_weights('model.h5') + + # Modified + def reset_layer(self, layer): + a,b = layer.get_weights()[0].shape + layer.set_weights([np.random.randn(a,b), np.ones(layer.get_weights()[1].shape)]) + + # Modified + def reset_last_layer(self): + """Reset the last layer(s) of the network.""" + self.reset_layer(self.dense1) + self.reset_layer(self.dense2) + self.reset_layer(self.last_layer) + + def call(self, state): """Creates the output tensor/op given the state tensor as input.""" x = tf.cast(state, tf.float32) @@ -158,6 +175,10 @@ def __init__(self, num_actions, name=None): self.net = BasicDiscreteDomainNetwork( CARTPOLE_MIN_VALS, CARTPOLE_MAX_VALS, num_actions) + # Modified + def reset_last_layer(self): + self.net.reset_last_layer() + def call(self, state): """Creates the output tensor/op given the state tensor as input.""" x = self.net(state) diff --git a/dopamine/discrete_domains/run_experiment.py b/dopamine/discrete_domains/run_experiment.py index 7e5d9ffc..a78b3e5f 100644 --- a/dopamine/discrete_domains/run_experiment.py +++ b/dopamine/discrete_domains/run_experiment.py @@ -173,7 +173,8 @@ def __init__(self, max_steps_per_episode=27000, clip_rewards=True, use_legacy_logger=True, - fine_grained_print_to_console=True): + fine_grained_print_to_console=True, + reset_period=None): """Initialize the Runner object in charge of running a full experiment. Args: @@ -234,6 +235,8 @@ def __init__(self, self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix) + self._reset_period = reset_period + # Create a collector dispatcher for metrics reporting. self._collector_dispatcher = collector_dispatcher.CollectorDispatcher( self._base_dir) @@ -603,6 +606,11 @@ def run_experiment(self): return for iteration in range(self._start_iteration, self._num_iterations): + # Modified: Check if the reset period is reached, and if so, reset the weights. + if (self._reset_period is not None and + iteration != 0 and iteration % self._reset_period == 0): + self._agent.ResetWeights() + statistics = self._run_one_iteration(iteration) if self._use_legacy_logger: self._log_experiment(iteration, statistics) diff --git a/dopamine/jax/agents/sac/sac_agent.py b/dopamine/jax/agents/sac/sac_agent.py index 9eea489d..ae29e348 100644 --- a/dopamine/jax/agents/sac/sac_agent.py +++ b/dopamine/jax/agents/sac/sac_agent.py @@ -283,7 +283,8 @@ def __init__(self, summary_writing_frequency=500, allow_partial_reload=False, seed=None, - collector_allowlist=('tensorboard')): + collector_allowlist=('tensorboard'), + reset_period=None): r"""Initializes the agent and constructs the necessary components. Args: @@ -387,6 +388,9 @@ def __init__(self, self.allow_partial_reload = allow_partial_reload self._collector_allowlist = collector_allowlist + # Reset period is used to reset the agent's state every reset_period steps. + self.reset_period = reset_period + self._rng = jax.random.PRNGKey(seed) state_shape = self.observation_shape + (stack_size,) self.state = onp.zeros(state_shape)