Deep Reinforcement Learning - CartPole Problem
I tried to implement the simplest Deep Q-Learning algorithm. I think I've implemented it correctly, and I know that Deep Q-Learning struggles with divergence, but the reward declines very quickly and the loss diverges. I would be grateful if someone could help me find the right hyperparameters or point out where my implementation is wrong. I've tried many hyperparameter combinations and have also changed the complexity of the QNet.
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import collections
import numpy as np
import matplotlib.pyplot as plt
import gym
from torch.nn.modules.linear import Linear
from torch.nn.modules.loss import MSELoss

class ReplayBuffer:
    def __init__(self, max_replay_size, batch_size):
        self.max_replay_size = max_replay_size
        self.batch_size = batch_size
        self.buffer = collections.deque()

    def push(self, *transition):
        if len(self.buffer) == self.max_replay_size:
            self.buffer.popleft()
        self.buffer.append(transition)

    def sample_batch(self):
        indices = np.random.choice(len(self.buffer), self.batch_size, replace = False)
        batch = [self.buffer[index] for index in indices]
        state, action, reward, next_state, done = zip(*batch)
        state = np.array(state)
        action = np.array(action)
        reward = np.array(reward)
        next_state = np.array(next_state)
        done = np.array(done)
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNet, self).__init__()
        self.linear1 = Linear(in_features = state_dim, out_features = 64)
        self.linear2 = Linear(in_features = 64, out_features = action_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

def train(replay_buffer, model, target_model, discount_factor, mse, optimizer):
    state, action, reward, next_state, _ = replay_buffer.sample_batch()
    state = torch.tensor(state, dtype = torch.float)
    next_state = torch.tensor(next_state, dtype = torch.float)
    # Compute Q Value and Target Q Value
    q_values = model(state).gather(1, torch.tensor(action, dtype = torch.int64).unsqueeze(-1))
    with torch.no_grad():
        max_next_q_values = target_model(next_state).detach().max(1)[0]
        q_target_value = torch.tensor(reward, dtype = torch.float) + discount_factor * max_next_q_values
    optimizer.zero_grad()
    loss = mse(q_values, q_target_value.unsqueeze(1))
    loss.backward()
    optimizer.step()
    return loss.item()

def main():
    # Define Hyperparameters and Parameters
    EPISODES = 10000
    MAX_REPLAY_SIZE = 10000
    BATCH_SIZE = 32
    EPSILON = 1.0
    MIN_EPSILON = 0.05
    DISCOUNT_FACTOR = 0.95
    DECAY_RATE = 0.99
    LEARNING_RATE = 1e-3
    SYNCHRONISATION = 33
    EVALUATION = 32

    # Initialize Environment, Model, Target-Model, Optimizer, Loss Function and Replay Buffer
    env = gym.make("CartPole-v0")
    model = QNet(state_dim = env.observation_space.shape[0], action_dim = env.action_space.n)
    target_model = QNet(state_dim = env.observation_space.shape[0], action_dim = env.action_space.n)
    target_model.load_state_dict(model.state_dict())
    optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
    mse = MSELoss()
    replay_buffer = ReplayBuffer(max_replay_size = MAX_REPLAY_SIZE, batch_size = BATCH_SIZE)
    while len(replay_buffer) != MAX_REPLAY_SIZE:
        state = env.reset()
        done = False
        while done != True:
            action = env.action_space.sample()
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state

    # Begin with the Main Loop where the QNet is trained
    count_until_synchronisation = 0
    count_until_evaluation = 0
    history = {'Episode': [], 'Reward': [], 'Loss': []}
    for episode in range(EPISODES):
        total_reward = 0.0
        total_loss = 0.0
        state = env.reset()
        iterations = 0
        done = False
        while done != True:
            count_until_synchronisation += 1
            count_until_evaluation += 1
            # Take an action
            if np.random.rand(1) < EPSILON:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    output = model(torch.tensor(state, dtype = torch.float)).numpy()
                    action = np.argmax(output)
            # Observe new state and reward + store into replay_buffer
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            if count_until_synchronisation % SYNCHRONISATION == 0:
                target_model.load_state_dict(model.state_dict())
            if count_until_evaluation % EVALUATION == 0:
                loss = train(replay_buffer = replay_buffer, model = model, target_model = target_model,
                             discount_factor = DISCOUNT_FACTOR, mse = mse, optimizer = optimizer)
                total_loss += loss
            iterations += 1
        print(f"Episode {episode} is concluded in {iterations} iterations with a total reward of {total_reward}")
        if EPSILON > MIN_EPSILON:
            EPSILON *= DECAY_RATE
        history['Episode'].append(episode)
        history['Reward'].append(total_reward)
        history['Loss'].append(total_loss)

    # Plot the Loss + Reward per Episode
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(history['Episode'], history['Reward'], label = "Reward")
    ax.set_xlabel('Episodes', fontsize = 15)
    ax.set_ylabel('Total Reward per Episode', fontsize = 15)
    plt.legend(prop = {'size': 15})
    plt.show()
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(history['Episode'], history['Loss'], label = "Loss")
    ax.set_xlabel('Episodes', fontsize = 15)
    ax.set_ylabel('Total Loss per Episode', fontsize = 15)
    plt.legend(prop = {'size': 15})
    plt.show()

if __name__ == "__main__":
    main()

Solution 1:[1]
Your code looks fine; I think your hyperparameters are just not ideal. I would change two, potentially three things:
- If I'm not mistaken, you update your target net every 33 steps (`SYNCHRONISATION = 33`). I think that is far too frequent. In the original paper by Mnih et al., they do a hard update every 10k steps. Think about it: the target net is used to calculate the loss, so you essentially change the loss function every 33 steps, which would be more than once per episode.
- Your replay buffer is pretty small. I would set its size to 100k or 1M, even if that is longer than what you intend to train for. If the replay buffer is too small, you lose the older transitions, which can cause your network to "forget" things it already learned. Not sure how dramatic this is for CartPole, but it may be worth trying.
- The learning rate could also be lower; I am using 1e-4 with RMSProp. Generally, changing the optimizer can also yield different results.
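As a rough illustration of the buffer suggestion, here is a minimal sketch of a larger replay buffer built on `collections.deque(maxlen=...)`, which evicts the oldest transition automatically. The constant names `REPLAY_CAPACITY`, `TARGET_SYNC_EVERY` and `LEARNING_RATE` are hypothetical and simply collect the values mentioned above; they are starting points to experiment with, not settings tuned for CartPole.

```python
import collections
import random


class ReplayBuffer:
    """Fixed-capacity FIFO buffer; the oldest transitions are evicted first."""

    def __init__(self, capacity, batch_size):
        # deque(maxlen=...) drops the oldest item automatically on overflow
        self.buffer = collections.deque(maxlen=capacity)
        self.batch_size = batch_size

    def push(self, *transition):
        self.buffer.append(transition)

    def sample_batch(self):
        # random.sample draws batch_size distinct transitions (no replacement)
        batch = random.sample(self.buffer, self.batch_size)
        # Regroup into (states, actions, rewards, next_states, dones)
        return tuple(zip(*batch))

    def __len__(self):
        return len(self.buffer)


# Hypothetical names holding the values suggested above (assumptions to tune)
REPLAY_CAPACITY = 100_000    # replay buffer size
TARGET_SYNC_EVERY = 10_000   # environment steps between hard target updates
LEARNING_RATE = 1e-4         # the answer pairs this with RMSProp

replay_buffer = ReplayBuffer(capacity=REPLAY_CAPACITY, batch_size=32)
```

Unlike the version in the question, `sample_batch` here returns tuples rather than NumPy arrays, so the caller would still convert them before building tensors.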
Hope this helps, good luck :)
Solution 2:[2]
Your code looks fine and is well written, and the hyperparameters seem reasonable (except maybe for the target update interval, which may be too short). I think that the Q network is quite small with a single hidden dense layer.
A deeper model could likely do better (probably not more than 3-4 layers though), but you said that you already tried different network sizes.
Another thing that comes to mind is the target update. You are doing a hard update every n steps; a soft update may help a bit, but I wouldn't count on it.
You can also try lowering the learning rate a bit, but I imagine you already did that.
My suggestions are:
- try less frequent target updates
- try a larger (deeper) network, something like 2-3 dense layers with 32 nodes, if you haven't already
- look into soft target updates (Polyak averaging and so on)
- try your implementation in other simple gym environments and check if its behavior is still the same.
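For the soft target update mentioned in the list, a minimal sketch of Polyak averaging in PyTorch could look like the following. The helper name `soft_update` and the mixing factor `tau` are assumptions for illustration, and the value 0.005 is only a commonly seen starting point, not something taken from the question.

```python
import torch
import torch.nn as nn


def soft_update(target_model, model, tau):
    """Move each target parameter a fraction tau toward the online parameter."""
    with torch.no_grad():
        for target_param, param in zip(target_model.parameters(), model.parameters()):
            # target <- (1 - tau) * target + tau * online
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)


# Tiny stand-in networks; in the question these would be the two QNet instances
model = nn.Linear(4, 2)
target_model = nn.Linear(4, 2)
target_model.load_state_dict(model.state_dict())  # start from identical weights

# Called after every optimizer step instead of the periodic load_state_dict
soft_update(target_model, model, tau=0.005)
```

With a small `tau` the target network drifts slowly toward the online network, which plays the same stabilizing role as an infrequent hard update.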
Sadly, DQN is not ideal and won't converge for many problems, but it should be able to solve CartPole.
Solution 3:[3]
This answer might be a little late for OP, but the implementation does not take terminal states into account. In some environments (e.g., CartPole), the only signal that the agent is doing something wrong comes from the terminal state. In CartPole the agent gets a +1 reward for every step it takes, but it has no notion of time, so if you don't create zero-value targets for terminal states, the value of every state tries to converge towards an infinite sum of discounted +1 rewards.
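To make the argument concrete, here is a short worked calculation, assuming the question's discount factor of 0.95 and a reward of +1 on every step. The symbols $y$ (target), $d$ (done flag, 1 at a terminal transition) and $Q_{\text{target}}$ (target network) are notation introduced here for illustration.

```latex
\begin{aligned}
\text{without terminal handling:}\quad
  & y = r + \gamma \max_{a'} Q_{\text{target}}(s', a')
  \;\Rightarrow\;
  Q \to \sum_{k=0}^{\infty} \gamma^{k} = \frac{1}{1-\gamma} = \frac{1}{1-0.95} = 20
  \quad \text{for every state} \\
\text{with terminal handling:}\quad
  & y = r + \gamma \,(1 - d)\, \max_{a'} Q_{\text{target}}(s', a')
  \;\Rightarrow\;
  y = r = 1 \quad \text{when } d = 1
\end{aligned}
```

In the first case every state ends up looking equally good, so the network has nothing to distinguish a balanced pole from one that is about to fall.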
The simplest way to include the zero values for terminal states is to also sample the batch of done flags and multiply the bootstrapped next-state values by (1 - done) before adding the reward, so that the target for a terminal transition is just the reward.
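A minimal sketch of that masking step, written with NumPy so it runs on its own. The helper name `td_targets` is hypothetical; the argument names mirror the variables in the question's `train` function.

```python
import numpy as np


def td_targets(reward, done, max_next_q, discount_factor):
    """One-step Q-learning targets with zero bootstrap value at terminal states."""
    reward = np.asarray(reward, dtype=np.float32)
    max_next_q = np.asarray(max_next_q, dtype=np.float32)
    # not_done is 0.0 where the episode ended and 1.0 otherwise
    not_done = 1.0 - np.asarray(done, dtype=np.float32)
    return reward + discount_factor * not_done * max_next_q


targets = td_targets(
    reward=[1.0, 1.0],
    done=[False, True],
    max_next_q=[10.0, 10.0],
    discount_factor=0.95,
)
# First entry bootstraps from the next state; the second is terminal,
# so its target is only the immediate reward.
print(targets)
```

In the question's `train` function, the same idea would mean keeping the `done` batch that is currently discarded as `_`, converting it to a float tensor, and computing `reward + discount_factor * (1 - done) * max_next_q_values`.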
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | frietz58 |
Solution 2 | |
Solution 3 | JimOhm |