PPO#
from typing import Self, Optional
from collections import namedtuple, deque
from dataclasses import dataclass
from itertools import count
import logging
import numpy as np
from scipy import stats
import gymnasium as gym
from gymnasium import Env
from gymnasium.wrappers.record_video import RecordVideo
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
from torch.distributions import Distribution, Normal
from torch.optim import Optimizer, Adam
from IPython.display import Video
torch.manual_seed(42)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(message)s"
)
env = gym.make("Pendulum-v1")
PPO Algorithm#
The following figure shows the pseudocode of the PPO-Clip algorithm:
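In brief, each iteration of the algorithm (as implemented in the rest of this notebook) proceeds as follows: collect a set of trajectories by running the current policy in the environment; compute the rewards-to-go and advantage estimates; update the actor for several epochs by maximizing the clipped surrogate objective; and fit the critic to the rewards-to-go with a mean-squared-error regression loss.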
Environment Specifications#
def get_state_dim(env: Env) -> Optional[int]:
    try:
        return env.observation_space.shape[0]
    except (AttributeError, IndexError, TypeError):
        # Spaces without a flat shape (e.g. Discrete) end up here
        return None
def get_action_dim(env: Env) -> Optional[int]:
    try:
        return env.action_space.shape[0]
    except (AttributeError, IndexError, TypeError):
        return None
print(f"state space dim: {get_state_dim(env)}")
print(f"action space dim: {get_action_dim(env)}")
state space dim: 3
action space dim: 1
@dataclass
class ActionSpec:
dim: int
min: np.ndarray
max: np.ndarray
@classmethod
def from_env(cls, env: Env) -> Self:
return cls(
dim=env.action_space.shape[0],
            min=env.action_space.low,
            max=env.action_space.high
)
state_dim = get_state_dim(env)
action_spec = ActionSpec.from_env(env)
print(f"state dim: {state_dim}")
print(f"action spec: {action_spec}")
state dim: 3
action spec: ActionSpec(dim=1, min=array([-2.], dtype=float32), max=array([2.], dtype=float32))
Actor and Critic Models#
class MLP(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int
) -> None:
super().__init__()
self.layer1 = nn.Linear(in_dim, 64)
self.layer2 = nn.Linear(64, 64)
self.layer3 = nn.Linear(64, out_dim)
def forward(self, state: Tensor) -> Tensor:
x = self.layer1(state)
x = F.relu(x)
x = self.layer2(x)
x = F.relu(x)
x = self.layer3(x)
return x
class Actor(nn.Module):
def __init__(
self,
state_dim: int,
action_spec: ActionSpec
) -> None:
super().__init__()
self._action_min = torch.tensor(action_spec.min, dtype=torch.float)
self._action_max = torch.tensor(action_spec.max, dtype=torch.float)
self._distribution = None
self._action = None
self.mlp = MLP(
in_dim=state_dim,
out_dim=action_spec.dim
)
self.dist_param1 = nn.Sequential(
nn.Linear(action_spec.dim, 16),
nn.ReLU(inplace=True),
nn.Linear(16, action_spec.dim),
nn.Sigmoid()
)
self.dist_param2 = nn.Sequential(
nn.Linear(action_spec.dim, 16),
nn.ReLU(inplace=True),
nn.Linear(16, action_spec.dim),
nn.Sigmoid()
)
@property
def distribution(self) -> Optional[Distribution]:
return self._distribution
def forward(self, state: Tensor) -> Tensor:
out: Tensor = self.mlp(state)
        # Squash the MLP output into the valid action range to obtain the mean
        mean = torch.sigmoid(out) * (self._action_max - self._action_min) + self._action_min
        # Use a fixed standard deviation
        # (dist_param1 and dist_param2 could instead predict the mean and std from `out`)
        std = torch.ones_like(mean) * 0.5
# Generate the distribution
distribution = Normal(mean, std)
# Store the distribution
self._distribution = distribution
return distribution
def select_action(
self,
state: Optional[Tensor] = None
) -> Tensor:
if state is not None:
self.forward(state)
# Sample an action from the distribution
action = self._distribution.sample()
# Clip the action value by upper and lower bounds
action = action.clip(self._action_min, self._action_max)
# Store the action taken
self._action = action
return action
def log_prob(
self,
action: Optional[Tensor] = None,
state: Optional[Tensor] = None
) -> Tensor:
if action is None:
            assert state is None, \
                "state must also be None when action is None"
return self._distribution.log_prob(self._action).sum(dim=-1)
if state is not None:
self.forward(state)
return self._distribution.log_prob(action).sum(dim=-1)
actor = Actor(state_dim, action_spec)
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float)
action = actor.select_action(state)
log_prob = actor.log_prob()
print(f"action: {action}")
print(f"log-probability: {log_prob}")
action: tensor([0.4161])
log-probability: -0.7357034683227539
actor.distribution
Normal(loc: tensor([-0.0888], grad_fn=<AddBackward0>), scale: tensor([0.5000]))
class Critic(nn.Module):
def __init__(self, state_dim: int) -> None:
super().__init__()
self.mlp = MLP(
in_dim=state_dim,
out_dim=1
)
def forward(self, state: Tensor) -> Tensor:
return self.evaluate(state)
def evaluate(self, state: Tensor) -> Tensor:
out: Tensor = self.mlp(state)
value = out.squeeze(dim=-1)
return value
critic = Critic(state_dim)
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float)
value = critic(state)
print(f"estimated state value: {value}")
estimated state value: 0.1734076738357544
def compute_rewards_to_go(rewards: list[float], gamma: float) -> list[float]:
rewards = rewards.copy()
rewards_to_go = deque()
while len(rewards) > 0:
# Iterate the rewards from back to front
reward = rewards.pop()
# The value of reward-to-go
# associated with the next state
if len(rewards_to_go) > 0:
next_reward_to_go = rewards_to_go[0]
else:
next_reward_to_go = 0
        # Compute the reward-to-go associated with the current state
reward_to_go = reward + gamma * next_reward_to_go
# Add to the front of the queue
rewards_to_go.appendleft(reward_to_go)
# Convert to list
rewards_to_go = list(rewards_to_go)
return rewards_to_go
compute_rewards_to_go([1, 10, 100], gamma=0.1)
[3.0, 20.0, 100.0]
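The function evaluates the backward recursion for discounted rewards-to-go,

$$
\hat{R}_t = r_t + \gamma \, \hat{R}_{t+1}, \qquad \hat{R}_T = r_T .
$$

For the example above with $\gamma = 0.1$: $\hat{R}_3 = 100$, $\hat{R}_2 = 10 + 0.1 \times 100 = 20$, and $\hat{R}_1 = 1 + 0.1 \times 20 = 3$, which matches the output [3.0, 20.0, 100.0].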
def compute_advantages(
critic: Critic,
rewards_to_go: list[float],
states: list[np.ndarray]
):
# State values
states = torch.tensor(np.array(states), dtype=torch.float)
values = critic.evaluate(states)
values = values.detach().numpy()
# Compute advantages
rewards_to_go = np.array(rewards_to_go)
advantages: np.ndarray = rewards_to_go - values
    # Normalize the advantages
advantages = stats.zscore(advantages)
return advantages
advantages = compute_advantages(
critic,
[1, 2, 5],
[state] * 3
)
print(f"advantages: {advantages}")
advantages: [-0.98058067 -0.39223227 1.37281295]
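Here the advantage of each state is estimated as the difference between the empirical reward-to-go and the critic's value estimate, followed by z-score normalization over the batch:

$$
\hat{A}_t = \mathrm{zscore}\big(\hat{R}_t - V_\phi(s_t)\big).
$$

Note that update_actor_critic below recomputes unnormalized advantages per mini-batch rather than calling this helper.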
Transition Data#
Transition = namedtuple(
"Transition",
(
"state",
"action",
"reward",
"reward_to_go",
"log_prob"
)
)
def convert_to_transition_with_fields_as_lists(transitions: list[Transition]) -> Transition:
return Transition(*map(list, zip(*transitions)))
def convert_to_transitions(transition_with_fields_as_list: Transition) -> Transition:
return list(map(
lambda fields: Transition(*fields),
zip(*transition_with_fields_as_list)
))
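The two helpers are inverses of each other. A minimal round-trip sketch with made-up scalar field values:

t1 = Transition(state=1, action=2, reward=3.0, reward_to_go=3.5, log_prob=-0.1)
t2 = Transition(state=4, action=5, reward=6.0, reward_to_go=6.5, log_prob=-0.2)
# Transpose a list of transitions into a single Transition whose fields are lists
batched = convert_to_transition_with_fields_as_lists([t1, t2])
# Transition(state=[1, 4], action=[2, 5], reward=[3.0, 6.0], reward_to_go=[3.5, 6.5], log_prob=[-0.1, -0.2])
# Transpose back into a list of transitions
assert convert_to_transitions(batched) == [t1, t2]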
def play_one_episode(
env: Env,
actor: Actor,
max_n_timesteps_per_episode: int,
gamma: float = 0.99
) -> Transition:
states = []
actions = []
rewards = []
log_probs = []
state, _ = env.reset()
is_done = False
for t in range(max_n_timesteps_per_episode):
# Select an action
state = torch.tensor(state, dtype=torch.float)
action = actor.select_action(state)
action = action.detach().numpy()
# Compute the log-probability of the action taken
log_prob = actor.log_prob()
log_prob = log_prob.detach().item()
# Interact with the env
next_state, reward, is_terminated, is_truncated, _ = env.step(action)
states.append(state)
actions.append(action)
rewards.append(reward)
log_probs.append(log_prob)
# The episode ends
is_done = is_terminated or is_truncated
if is_done:
break
# Step to the next state
state = next_state
# Compute the rewards-to-go
# based on the received rewards of entire episode
rewards_to_go = compute_rewards_to_go(rewards, gamma)
return Transition(
state=states,
action=actions,
reward=rewards,
reward_to_go=rewards_to_go,
log_prob=log_probs,
)
class ReplayBuffer(deque, Dataset):
def __init__(
self,
env: Env,
capacity: int,
max_n_timesteps_per_episode: int,
gamma: float
) -> None:
super().__init__(maxlen=capacity)
self._env = env
self._capacity = capacity
self._max_n_timesteps_per_episode = max_n_timesteps_per_episode
self._gamma = gamma
@property
def capacity(self) -> int:
return self._capacity
@property
def max_n_timesteps_per_episode(self) -> int:
return self._max_n_timesteps_per_episode
def collect(self, actor: Actor) -> None:
while len(self) < self._capacity:
transition_with_fields_as_list = play_one_episode(
env=self._env,
actor=actor,
max_n_timesteps_per_episode=self._max_n_timesteps_per_episode,
gamma=self._gamma
)
# Convert to list of transitions
transitions = convert_to_transitions(transition_with_fields_as_list)
# Add to the buffer
self.extend(transitions)
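As a quick sanity check, the buffer can be filled with the untrained actor defined earlier; the capacity and episode length below are arbitrary illustrative values:

buffer = ReplayBuffer(
    env=env,
    capacity=256,
    max_n_timesteps_per_episode=200,
    gamma=0.99
)
buffer.collect(actor)
print(f"buffer size: {len(buffer)}")  # equals the capacity once collection finishes
print(buffer[0]._fields)              # ('state', 'action', 'reward', 'reward_to_go', 'log_prob')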
Updating Actor and Critic#
actor_opt = Adam(actor.parameters())
critic_opt = Adam(critic.parameters())
def update_actor_critic(
actor: Actor,
critic: Critic,
actor_opt: Optimizer,
critic_opt: Optimizer,
replay_buffer_loader: DataLoader,
n_epochs: int,
epsilon: float = 0.2
):
for _ in range(n_epochs):
transition: Transition
for transition in replay_buffer_loader:
batch_states = transition.state
batch_actions = transition.action
            log_probs = actor.log_prob(batch_actions, batch_states)
            batch_log_probs = transition.log_prob
            ratios = torch.exp(log_probs - batch_log_probs)
batch_rewards_to_go = transition.reward_to_go.type(torch.float)
state_values = critic.evaluate(batch_states)
advantages = batch_rewards_to_go - state_values.detach()
surr1 = ratios * advantages
surr2 = torch.clip(
ratios,
1 - epsilon,
1 + epsilon
) * advantages
actor_loss = -torch.min(surr1, surr2).mean()
# Update actor
actor_opt.zero_grad()
actor_loss.backward()
actor_opt.step()
critic_loss = F.mse_loss(batch_rewards_to_go, state_values)
# Update critic
critic_opt.zero_grad()
critic_loss.backward()
critic_opt.step()
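In the loop above, ratios implements the probability ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_\mathrm{old}}(a_t \mid s_t)$ via the difference of log-probabilities, and surr1 and surr2 are the two terms inside the min of the clipped surrogate objective. The losses being optimized are therefore (standard PPO-Clip form):

$$
\mathcal{L}_\mathrm{actor}(\theta) = -\,\mathbb{E}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right],
\qquad
\mathcal{L}_\mathrm{critic}(\phi) = \mathbb{E}_t\!\left[\big(\hat{R}_t - V_\phi(s_t)\big)^2\right],
$$

with $\epsilon = 0.2$ by default and $\hat{A}_t = \hat{R}_t - V_\phi(s_t)$ detached from the critic's computation graph.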
@dataclass
class PPOConfig:
# Total number of epochs of the PPO algorithm
n_epochs: int
replay_buffer_capacity: int
max_n_timesteps_per_episode: int
# Discount factor
gamma: float
# Batch size of the data loader
batch_size: int
    # Number of epochs to update the actor and critic networks
n_epochs_for_actor_critic: int
# Learning rate of the Adam optimizers
lr: float
PPO Class#
class PPO:
def __init__(
self,
env: Env,
config: PPOConfig
) -> None:
self._env = env
self._config = config
state_dim = get_state_dim(env)
action_spec = ActionSpec.from_env(env)
# Actor and critic networks
self._actor = Actor(state_dim, action_spec)
self._critic = Critic(state_dim)
# Optimizers
self._actor_opt = Adam(self._actor.parameters(), lr=self._config.lr)
self._critic_opt = Adam(self._critic.parameters(), lr=self._config.lr)
# Replay buffer
self._replay_buffer = ReplayBuffer(
env=env,
capacity=self._config.replay_buffer_capacity,
max_n_timesteps_per_episode=self._config.max_n_timesteps_per_episode,
gamma=self._config.gamma
)
@property
def actor(self) -> Actor:
return self._actor
@property
def critic(self) -> Critic:
return self._critic
def learn(self):
for epoch in range(self._config.n_epochs):
logging.info(f"PPO epoch: {epoch + 1}")
            # Collect transitions
            self._replay_buffer.collect(self._actor)
            # Average per-transition reward over the collected buffer
            rewards = [transition.reward for transition in self._replay_buffer]
            avg_episode_reward = np.mean(rewards)
            logging.info(f"average episode rewards: {avg_episode_reward}")
# Create a data loader
replay_buffer_loader = DataLoader(
self._replay_buffer,
batch_size=self._config.batch_size
)
# Train the actor and critic
update_actor_critic(
actor=self._actor,
critic=self._critic,
actor_opt=self._actor_opt,
critic_opt=self._critic_opt,
replay_buffer_loader=replay_buffer_loader,
n_epochs=self._config.n_epochs_for_actor_critic
)
# Clear replay buffer
self._replay_buffer.clear()
Training#
ppo = PPO(
env,
config=PPOConfig(
n_epochs=100,
replay_buffer_capacity=4800,
max_n_timesteps_per_episode=1600,
gamma=0.99,
batch_size=1024,
n_epochs_for_actor_critic=10,
lr=0.01
)
)
ppo.learn()
2023-09-21 14:43:53,491 | INFO | PPO epoch: 1
2023-09-21 14:43:54,746 | INFO | average episode rewards: -6.110923935401077
2023-09-21 14:43:55,076 | INFO | PPO epoch: 2
2023-09-21 14:43:56,339 | INFO | average episode rewards: -6.071260639772841
2023-09-21 14:43:56,656 | INFO | PPO epoch: 3
2023-09-21 14:43:57,937 | INFO | average episode rewards: -6.286279557600787
2023-09-21 14:43:58,241 | INFO | PPO epoch: 4
2023-09-21 14:43:59,476 | INFO | average episode rewards: -5.504663757900611
2023-09-21 14:43:59,785 | INFO | PPO epoch: 5
2023-09-21 14:44:01,035 | INFO | average episode rewards: -6.153235834902657
2023-09-21 14:44:01,327 | INFO | PPO epoch: 6
2023-09-21 14:44:02,545 | INFO | average episode rewards: -6.730246981997852
2023-09-21 14:44:02,837 | INFO | PPO epoch: 7
2023-09-21 14:44:04,076 | INFO | average episode rewards: -5.464440702286453
2023-09-21 14:44:04,458 | INFO | PPO epoch: 8
2023-09-21 14:44:05,671 | INFO | average episode rewards: -5.593644874240346
2023-09-21 14:44:05,962 | INFO | PPO epoch: 9
2023-09-21 14:44:07,171 | INFO | average episode rewards: -6.399172619649428
2023-09-21 14:44:07,462 | INFO | PPO epoch: 10
2023-09-21 14:44:08,685 | INFO | average episode rewards: -6.022530136556304
2023-09-21 14:44:08,983 | INFO | PPO epoch: 11
2023-09-21 14:44:10,215 | INFO | average episode rewards: -6.050024077415373
2023-09-21 14:44:10,573 | INFO | PPO epoch: 12
2023-09-21 14:44:11,792 | INFO | average episode rewards: -5.595678804214014
2023-09-21 14:44:12,085 | INFO | PPO epoch: 13
2023-09-21 14:44:13,291 | INFO | average episode rewards: -5.7897880148753895
2023-09-21 14:44:13,583 | INFO | PPO epoch: 14
2023-09-21 14:44:14,799 | INFO | average episode rewards: -5.937452429329713
2023-09-21 14:44:15,093 | INFO | PPO epoch: 15
2023-09-21 14:44:16,314 | INFO | average episode rewards: -5.859606695803003
2023-09-21 14:44:16,672 | INFO | PPO epoch: 16
2023-09-21 14:44:17,887 | INFO | average episode rewards: -5.676640589709954
2023-09-21 14:44:18,195 | INFO | PPO epoch: 17
2023-09-21 14:44:19,419 | INFO | average episode rewards: -5.387011686028222
2023-09-21 14:44:19,744 | INFO | PPO epoch: 18
2023-09-21 14:44:20,956 | INFO | average episode rewards: -5.776525081077417
2023-09-21 14:44:21,245 | INFO | PPO epoch: 19
2023-09-21 14:44:22,450 | INFO | average episode rewards: -5.472712473588778
2023-09-21 14:44:22,805 | INFO | PPO epoch: 20
2023-09-21 14:44:24,005 | INFO | average episode rewards: -5.2392058947750675
2023-09-21 14:44:24,295 | INFO | PPO epoch: 21
2023-09-21 14:44:25,506 | INFO | average episode rewards: -6.096322011961469
2023-09-21 14:44:25,794 | INFO | PPO epoch: 22
2023-09-21 14:44:27,000 | INFO | average episode rewards: -5.723285307642029
2023-09-21 14:44:27,288 | INFO | PPO epoch: 23
2023-09-21 14:44:28,511 | INFO | average episode rewards: -5.781585879501909
2023-09-21 14:44:28,864 | INFO | PPO epoch: 24
2023-09-21 14:44:30,132 | INFO | average episode rewards: -5.404026350685467
2023-09-21 14:44:30,459 | INFO | PPO epoch: 25
2023-09-21 14:44:31,703 | INFO | average episode rewards: -5.59045952709884
2023-09-21 14:44:32,024 | INFO | PPO epoch: 26
2023-09-21 14:44:33,242 | INFO | average episode rewards: -5.630707363056575
2023-09-21 14:44:33,556 | INFO | PPO epoch: 27
2023-09-21 14:44:34,826 | INFO | average episode rewards: -5.025507696244329
2023-09-21 14:44:35,217 | INFO | PPO epoch: 28
2023-09-21 14:44:36,527 | INFO | average episode rewards: -5.972768697684134
2023-09-21 14:44:36,915 | INFO | PPO epoch: 29
2023-09-21 14:44:38,185 | INFO | average episode rewards: -5.367618255026877
2023-09-21 14:44:38,477 | INFO | PPO epoch: 30
2023-09-21 14:44:39,777 | INFO | average episode rewards: -5.140174826735197
2023-09-21 14:44:40,114 | INFO | PPO epoch: 31
2023-09-21 14:44:41,363 | INFO | average episode rewards: -4.981673487006734
2023-09-21 14:44:41,729 | INFO | PPO epoch: 32
2023-09-21 14:44:42,991 | INFO | average episode rewards: -5.127025520256155
2023-09-21 14:44:43,300 | INFO | PPO epoch: 33
2023-09-21 14:44:44,525 | INFO | average episode rewards: -5.3419723138993955
2023-09-21 14:44:44,813 | INFO | PPO epoch: 34
2023-09-21 14:44:46,031 | INFO | average episode rewards: -5.30891222853752
2023-09-21 14:44:46,318 | INFO | PPO epoch: 35
2023-09-21 14:44:47,529 | INFO | average episode rewards: -5.684426722799187
2023-09-21 14:44:47,885 | INFO | PPO epoch: 36
2023-09-21 14:44:49,138 | INFO | average episode rewards: -4.8530052231381
2023-09-21 14:44:49,477 | INFO | PPO epoch: 37
2023-09-21 14:44:50,708 | INFO | average episode rewards: -5.166183728959908
2023-09-21 14:44:50,999 | INFO | PPO epoch: 38
2023-09-21 14:44:52,218 | INFO | average episode rewards: -4.636682895422886
2023-09-21 14:44:52,507 | INFO | PPO epoch: 39
2023-09-21 14:44:53,753 | INFO | average episode rewards: -4.71075184322738
2023-09-21 14:44:54,224 | INFO | PPO epoch: 40
2023-09-21 14:44:55,455 | INFO | average episode rewards: -5.1569811687949425
2023-09-21 14:44:55,763 | INFO | PPO epoch: 41
2023-09-21 14:44:57,035 | INFO | average episode rewards: -4.79638571747526
2023-09-21 14:44:57,336 | INFO | PPO epoch: 42
2023-09-21 14:44:58,561 | INFO | average episode rewards: -4.835689811830781
2023-09-21 14:44:58,849 | INFO | PPO epoch: 43
2023-09-21 14:45:00,107 | INFO | average episode rewards: -4.282140943069902
2023-09-21 14:45:00,494 | INFO | PPO epoch: 44
2023-09-21 14:45:01,763 | INFO | average episode rewards: -4.510575011972921
2023-09-21 14:45:02,050 | INFO | PPO epoch: 45
2023-09-21 14:45:03,255 | INFO | average episode rewards: -4.3932612586966036
2023-09-21 14:45:03,540 | INFO | PPO epoch: 46
2023-09-21 14:45:04,737 | INFO | average episode rewards: -4.474887797060449
2023-09-21 14:45:05,021 | INFO | PPO epoch: 47
2023-09-21 14:45:06,216 | INFO | average episode rewards: -4.7754210740039325
2023-09-21 14:45:06,608 | INFO | PPO epoch: 48
2023-09-21 14:45:07,833 | INFO | average episode rewards: -4.3594883025611475
2023-09-21 14:45:08,129 | INFO | PPO epoch: 49
2023-09-21 14:45:09,351 | INFO | average episode rewards: -4.323283747093484
2023-09-21 14:45:09,641 | INFO | PPO epoch: 50
2023-09-21 14:45:10,858 | INFO | average episode rewards: -4.232465984200838
2023-09-21 14:45:11,172 | INFO | PPO epoch: 51
2023-09-21 14:45:12,394 | INFO | average episode rewards: -4.255824298900807
2023-09-21 14:45:12,744 | INFO | PPO epoch: 52
2023-09-21 14:45:14,001 | INFO | average episode rewards: -4.308886162348113
2023-09-21 14:45:14,401 | INFO | PPO epoch: 53
2023-09-21 14:45:15,710 | INFO | average episode rewards: -4.323119418666096
2023-09-21 14:45:16,027 | INFO | PPO epoch: 54
2023-09-21 14:45:17,249 | INFO | average episode rewards: -4.257481836218371
2023-09-21 14:45:17,541 | INFO | PPO epoch: 55
2023-09-21 14:45:18,771 | INFO | average episode rewards: -4.084458802543108
2023-09-21 14:45:19,127 | INFO | PPO epoch: 56
2023-09-21 14:45:20,362 | INFO | average episode rewards: -4.169530096280342
2023-09-21 14:45:20,654 | INFO | PPO epoch: 57
2023-09-21 14:45:21,872 | INFO | average episode rewards: -4.137685485908213
2023-09-21 14:45:22,162 | INFO | PPO epoch: 58
2023-09-21 14:45:23,372 | INFO | average episode rewards: -4.249553402858064
2023-09-21 14:45:23,664 | INFO | PPO epoch: 59
2023-09-21 14:45:24,873 | INFO | average episode rewards: -3.964422105153224
2023-09-21 14:45:25,227 | INFO | PPO epoch: 60
2023-09-21 14:45:26,440 | INFO | average episode rewards: -4.064252007168193
2023-09-21 14:45:26,730 | INFO | PPO epoch: 61
2023-09-21 14:45:27,941 | INFO | average episode rewards: -3.8129484918542738
2023-09-21 14:45:28,243 | INFO | PPO epoch: 62
2023-09-21 14:45:29,453 | INFO | average episode rewards: -4.10846290588343
2023-09-21 14:45:29,759 | INFO | PPO epoch: 63
2023-09-21 14:45:30,966 | INFO | average episode rewards: -4.100469735873232
2023-09-21 14:45:31,319 | INFO | PPO epoch: 64
2023-09-21 14:45:32,535 | INFO | average episode rewards: -3.988986004465981
2023-09-21 14:45:32,823 | INFO | PPO epoch: 65
2023-09-21 14:45:34,042 | INFO | average episode rewards: -4.172985700446049
2023-09-21 14:45:34,329 | INFO | PPO epoch: 66
2023-09-21 14:45:35,544 | INFO | average episode rewards: -3.9492351237683425
2023-09-21 14:45:35,834 | INFO | PPO epoch: 67
2023-09-21 14:45:37,055 | INFO | average episode rewards: -4.037028037636843
2023-09-21 14:45:37,408 | INFO | PPO epoch: 68
2023-09-21 14:45:38,632 | INFO | average episode rewards: -3.8717523354590213
2023-09-21 14:45:38,919 | INFO | PPO epoch: 69
2023-09-21 14:45:40,139 | INFO | average episode rewards: -3.6380695252078232
2023-09-21 14:45:40,427 | INFO | PPO epoch: 70
2023-09-21 14:45:41,644 | INFO | average episode rewards: -3.7535559594569334
2023-09-21 14:45:41,932 | INFO | PPO epoch: 71
2023-09-21 14:45:43,145 | INFO | average episode rewards: -3.717212148282896
2023-09-21 14:45:43,497 | INFO | PPO epoch: 72
2023-09-21 14:45:44,707 | INFO | average episode rewards: -3.8200672704970398
2023-09-21 14:45:44,997 | INFO | PPO epoch: 73
2023-09-21 14:45:46,217 | INFO | average episode rewards: -3.8056665582783458
2023-09-21 14:45:46,505 | INFO | PPO epoch: 74
2023-09-21 14:45:47,727 | INFO | average episode rewards: -3.671062882379162
2023-09-21 14:45:48,017 | INFO | PPO epoch: 75
2023-09-21 14:45:49,240 | INFO | average episode rewards: -3.6105846328232505
2023-09-21 14:45:49,601 | INFO | PPO epoch: 76
2023-09-21 14:45:50,822 | INFO | average episode rewards: -3.5088284661623703
2023-09-21 14:45:51,111 | INFO | PPO epoch: 77
2023-09-21 14:45:52,331 | INFO | average episode rewards: -3.463422440033662
2023-09-21 14:45:52,619 | INFO | PPO epoch: 78
2023-09-21 14:45:53,830 | INFO | average episode rewards: -3.390544654137758
2023-09-21 14:45:54,118 | INFO | PPO epoch: 79
2023-09-21 14:45:55,336 | INFO | average episode rewards: -3.5680100887619566
2023-09-21 14:45:55,688 | INFO | PPO epoch: 80
2023-09-21 14:45:56,900 | INFO | average episode rewards: -3.352304634466086
2023-09-21 14:45:57,190 | INFO | PPO epoch: 81
2023-09-21 14:45:58,413 | INFO | average episode rewards: -3.308571335618333
2023-09-21 14:45:58,704 | INFO | PPO epoch: 82
2023-09-21 14:45:59,933 | INFO | average episode rewards: -3.363160211521526
2023-09-21 14:46:00,222 | INFO | PPO epoch: 83
2023-09-21 14:46:01,427 | INFO | average episode rewards: -3.336297793707316
2023-09-21 14:46:01,779 | INFO | PPO epoch: 84
2023-09-21 14:46:03,000 | INFO | average episode rewards: -3.1232256799079816
2023-09-21 14:46:03,288 | INFO | PPO epoch: 85
2023-09-21 14:46:04,505 | INFO | average episode rewards: -2.893729036215459
2023-09-21 14:46:04,795 | INFO | PPO epoch: 86
2023-09-21 14:46:06,015 | INFO | average episode rewards: -2.7299993584978868
2023-09-21 14:46:06,303 | INFO | PPO epoch: 87
2023-09-21 14:46:07,514 | INFO | average episode rewards: -2.615799076306543
2023-09-21 14:46:07,869 | INFO | PPO epoch: 88
2023-09-21 14:46:09,112 | INFO | average episode rewards: -2.207611983904887
2023-09-21 14:46:09,409 | INFO | PPO epoch: 89
2023-09-21 14:46:10,650 | INFO | average episode rewards: -1.9580470846103188
2023-09-21 14:46:10,939 | INFO | PPO epoch: 90
2023-09-21 14:46:12,155 | INFO | average episode rewards: -1.6758821685748795
2023-09-21 14:46:12,446 | INFO | PPO epoch: 91
2023-09-21 14:46:13,662 | INFO | average episode rewards: -1.1631944322504268
2023-09-21 14:46:14,014 | INFO | PPO epoch: 92
2023-09-21 14:46:15,230 | INFO | average episode rewards: -1.1940382901557742
2023-09-21 14:46:15,519 | INFO | PPO epoch: 93
2023-09-21 14:46:16,730 | INFO | average episode rewards: -1.186452953358785
2023-09-21 14:46:17,020 | INFO | PPO epoch: 94
2023-09-21 14:46:18,232 | INFO | average episode rewards: -1.1377638725134356
2023-09-21 14:46:18,523 | INFO | PPO epoch: 95
2023-09-21 14:46:19,759 | INFO | average episode rewards: -1.2804482369598744
2023-09-21 14:46:20,116 | INFO | PPO epoch: 96
2023-09-21 14:46:21,333 | INFO | average episode rewards: -1.2031686884665314
2023-09-21 14:46:21,622 | INFO | PPO epoch: 97
2023-09-21 14:46:22,836 | INFO | average episode rewards: -1.217157130940351
2023-09-21 14:46:23,124 | INFO | PPO epoch: 98
2023-09-21 14:46:24,343 | INFO | average episode rewards: -1.0281718137516214
2023-09-21 14:46:24,632 | INFO | PPO epoch: 99
2023-09-21 14:46:25,864 | INFO | average episode rewards: -1.0543439617908212
2023-09-21 14:46:26,218 | INFO | PPO epoch: 100
2023-09-21 14:46:27,457 | INFO | average episode rewards: -0.9634774535826106
torch.save(ppo.actor, "../models/actor.pth")
torch.save(ppo.critic, "../models/critic.pth")
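Saving the whole module pickles the class definition together with the weights, which ties the checkpoint to this exact code layout. A more portable alternative (a sketch only, not what the rest of this notebook uses; the file names here are hypothetical) is to save and load state dicts:

# Save only the parameters
torch.save(ppo.actor.state_dict(), "../models/actor_state.pth")
torch.save(ppo.critic.state_dict(), "../models/critic_state.pth")
# Later: rebuild the modules with the same specs and load the weights
actor = Actor(state_dim, action_spec)
actor.load_state_dict(torch.load("../models/actor_state.pth"))
actor.eval()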
Loading Models#
actor = torch.load("../models/actor.pth")
critic = torch.load("../models/critic.pth")
actor.eval()
critic.eval()
print(actor)
print(critic)
Actor(
(mlp): MLP(
(layer1): Linear(in_features=3, out_features=64, bias=True)
(layer2): Linear(in_features=64, out_features=64, bias=True)
(layer3): Linear(in_features=64, out_features=1, bias=True)
)
(dist_param1): Sequential(
(0): Linear(in_features=1, out_features=16, bias=True)
(1): ReLU(inplace=True)
(2): Linear(in_features=16, out_features=1, bias=True)
(3): Sigmoid()
)
(dist_param2): Sequential(
(0): Linear(in_features=1, out_features=16, bias=True)
(1): ReLU(inplace=True)
(2): Linear(in_features=16, out_features=1, bias=True)
(3): Sigmoid()
)
)
Critic(
(mlp): MLP(
(layer1): Linear(in_features=3, out_features=64, bias=True)
(layer2): Linear(in_features=64, out_features=64, bias=True)
(layer3): Linear(in_features=64, out_features=1, bias=True)
)
)
Performance of the Trained Agent#
env = gym.make(
"Pendulum-v1",
max_episode_steps=1000,
render_mode="rgb_array"
)
env = RecordVideo(
env,
video_folder="../_static/videos",
name_prefix=f"ppo-{env.spec.id}"
)
state, _ = env.reset()
for t in count():
# Select an action
state = torch.tensor(state, dtype=torch.float)
action = actor.select_action(state)
action = action.detach().numpy()
    # Interact with the environment
next_state, reward, is_terminated, is_truncated, info = env.step(action)
is_done = is_terminated or is_truncated
if is_done:
break
# Step to the next state
state = next_state
env.close()
/opt/anaconda3/envs/linguaml/lib/python3.11/site-packages/gymnasium/wrappers/record_video.py:94: UserWarning: WARN: Overwriting existing videos at /Users/isaac/Developer/py-projects/linguAML/book/_static/videos folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)
logger.warn(
Moviepy - Building video /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4.
Moviepy - Writing video /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4
Moviepy - Done !
Moviepy - video ready /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4
Video("../_static/videos/ppo-Pendulum-v1-episode-0.mp4")
