PPO#

from typing import Self, Optional
from collections import namedtuple, deque
from dataclasses import dataclass
from itertools import count
import logging
import numpy as np
from scipy import stats

import gymnasium as gym
from gymnasium import Env
from gymnasium.wrappers.record_video import RecordVideo

import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
from torch.distributions import Distribution, Normal
from torch.optim import Optimizer, Adam

from IPython.display import Video
torch.manual_seed(42)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)

env = gym.make("Pendulum-v1")

PPO Algorithm#

The following figure shows the pseudocode of the PPO-Clip algorithm:

../_images/ppo-clip.png
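
For reference, the actor update implemented later in this notebook maximizes the standard PPO-Clip surrogate objective

$$
L^{\mathrm{CLIP}}(\theta)
= \mathbb{E}_t\!\left[\min\!\Big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\Big)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)},
$$

where $\hat{A}_t$ is the advantage estimate and $\epsilon$ is the clipping parameter (0.2 by default in the update code below).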

Environment Specifications#

def get_state_dim(env: Env) -> Optional[int]:
    
    try:
        return env.observation_space.shape[0]
    except (AttributeError, TypeError, IndexError):
        return None

def get_action_dim(env: Env) -> Optional[int]:
    
    try:
        return env.action_space.shape[0]
    except (AttributeError, TypeError, IndexError):
        return None
    
print(f"state space dim: {get_state_dim(env)}")
print(f"action space dim: {get_action_dim(env)}")
state space dim: 3
action space dim: 1
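
In Pendulum-v1 the 3-dimensional observation is (cos θ, sin θ, angular velocity) and the single action is the torque applied to the pendulum, bounded to [-2, 2], as the action spec extracted below confirms.
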
@dataclass
class ActionSpec:
    dim: int
    min: np.ndarray
    max: np.ndarray
    
    @classmethod
    def from_env(cls, env: Env) -> Self:
        return cls(
            dim=env.action_space.shape[0],
            min=env.action_space.low,
            max=env.action_space.high
        )

state_dim = get_state_dim(env)
action_spec = ActionSpec.from_env(env)

print(f"state dim: {state_dim}")
print(f"action spec: {action_spec}")
state dim: 3
action spec: ActionSpec(dim=1, min=array([-2.], dtype=float32), max=array([2.], dtype=float32))

Actor and Critic Models#

class MLP(nn.Module):
    
    def __init__(
            self,
            in_dim: int,
            out_dim: int
        ) -> None:
        
        super().__init__()
        
        self.layer1 = nn.Linear(in_dim, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, out_dim)
    
    def forward(self, state: Tensor) -> Tensor:
        
        x = self.layer1(state)
        x = F.relu(x)
        x = self.layer2(x)
        x = F.relu(x)
        x = self.layer3(x)
        
        return x
class Actor(nn.Module):
    
    def __init__(
            self,
            state_dim: int,
            action_spec: ActionSpec
        ) -> None:
        
        super().__init__()
        
        
        self._action_min = torch.tensor(action_spec.min, dtype=torch.float)
        self._action_max = torch.tensor(action_spec.max, dtype=torch.float)
        
        self._distribution = None
        self._action = None
        
        self.mlp = MLP(
            in_dim=state_dim,
            out_dim=action_spec.dim
        )
        
        self.dist_param1 = nn.Sequential(
            nn.Linear(action_spec.dim, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, action_spec.dim),
            nn.Sigmoid()
        )
        
        self.dist_param2 = nn.Sequential(
            nn.Linear(action_spec.dim, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, action_spec.dim),
            nn.Sigmoid()
        )
        
    @property
    def distribution(self) -> Optional[Distribution]:
        
        return self._distribution
        
    
    def forward(self, state: Tensor) -> Distribution:
        
        out: Tensor = self.mlp(state)
        
        # Outputs of the two distribution-parameter heads.
        # They are currently unused; the commented-out lines below
        # show how they could parameterize the mean and std instead.
        param1 = self.dist_param1(out)
        param2 = self.dist_param2(out)
        
        # Squash the MLP output into the action range to obtain the mean,
        # and use a fixed standard deviation of 0.5
        mean = torch.sigmoid(out) * (self._action_max - self._action_min) + self._action_min
        # mean = param1 * (self._action_max - self._action_min) + self._action_min
        # std = param2
        std = torch.ones_like(mean) * 0.5
        
        # Generate the distribution
        distribution = Normal(mean, std)
        
        # Store the distribution
        self._distribution = distribution
        
        return distribution
    
    def select_action(
            self, 
            state: Optional[Tensor] = None
        ) -> Tensor:
        
        if state is not None:
            self.forward(state)

        # Sample an action from the distribution
        action = self._distribution.sample()
        
        # Clip the action value by upper and lower bounds
        action = action.clip(self._action_min, self._action_max)
        
        # Store the action taken
        self._action = action
        
        return action
    
    def log_prob(
            self, 
            action: Optional[Tensor] = None,
            state: Optional[Tensor] = None
        ) -> Tensor:
        
        if action is None:
            assert state is None, \
                "state must also be None when action is None"
            # Use the most recently sampled action
            return self._distribution.log_prob(self._action).sum(dim=-1)
        
        if state is not None:
            self.forward(state)
            
        return self._distribution.log_prob(action).sum(dim=-1)

actor = Actor(state_dim, action_spec)
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float)
action = actor.select_action(state)
log_prob = actor.log_prob()

print(f"action: {action}")
print(f"log-probability: {log_prob}")
action: tensor([0.4161])
log-probability: -0.7357034683227539
actor.distribution
Normal(loc: tensor([-0.0888], grad_fn=<AddBackward0>), scale: tensor([0.5000]))
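
As a sanity check, the printed log-probability agrees with the Gaussian log-density of the sampled action under the distribution shown above:

$$
\log p(a) = -\log\!\big(\sigma\sqrt{2\pi}\big) - \frac{(a-\mu)^2}{2\sigma^2}
= -\log\!\big(0.5\sqrt{2\pi}\big) - \frac{\big(0.4161-(-0.0888)\big)^2}{2 \cdot 0.25}
\approx -0.736,
$$

which matches the printed value up to the rounding of the displayed mean and action.
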
class Critic(nn.Module):
    
    def __init__(self, state_dim: int) -> None:
        
        super().__init__()
        
        
        self.mlp = MLP(
            in_dim=state_dim,
            out_dim=1
        )

    def forward(self, state: Tensor) -> Tensor:
        
        return self.evaluate(state)
    
    def evaluate(self, state: Tensor) -> Tensor:
        
        out: Tensor = self.mlp(state)
        value = out.squeeze(dim=-1)
        
        return value

critic = Critic(state_dim)
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float)
value = critic(state)

print(f"estimated state value: {value}")
estimated state value: 0.1734076738357544
def compute_rewards_to_go(rewards: list[float], gamma: float) -> list[float]:
    
    rewards = rewards.copy()
    rewards_to_go = deque()
    
    while len(rewards) > 0:
        # Iterate the rewards from back to front
        reward = rewards.pop()
        
        # The value of reward-to-go 
        # associated with the next state
        if len(rewards_to_go) > 0:
            next_reward_to_go = rewards_to_go[0]
        else:
            next_reward_to_go = 0
        
        # Compute the reward-to-go associated with the current state
        reward_to_go = reward + gamma * next_reward_to_go
        
        # Add to the front of the queue
        rewards_to_go.appendleft(reward_to_go)
    
    # Convert to list
    rewards_to_go = list(rewards_to_go)
    
    return rewards_to_go

compute_rewards_to_go([1, 10, 100], gamma=0.1)
[3.0, 20.0, 100.0]
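
Reading the rewards from back to front explains this output: the last reward-to-go is simply $100$, the middle one is $10 + 0.1 \cdot 100 = 20$, and the first is $1 + 0.1 \cdot 20 = 3$; each entry satisfies $\hat{R}_t = r_t + \gamma\,\hat{R}_{t+1}$.
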
def compute_advantages(
        critic: Critic,
        rewards_to_go: list[float], 
        states: list[np.ndarray]
    ) -> np.ndarray:
    
    # State values
    states = torch.tensor(np.array(states), dtype=torch.float)
    values = critic.evaluate(states)
    values = values.detach().numpy()
    
    # Compute advantages
    rewards_to_go = np.array(rewards_to_go)
    advantages: np.ndarray = rewards_to_go - values
    
    # Normalize the advantages
    advantages = stats.zscore(advantages)
    
    return advantages

advantages = compute_advantages(
    critic,
    [1, 2, 5],
    [state] * 3
)

print(f"advantages: {advantages}")
advantages: [-0.98058067 -0.39223227  1.37281295]
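
Because the same state is repeated three times, the critic assigns all three entries the same value, so the raw advantages $\hat{R}_t - V(s_t)$ differ from the rewards-to-go $[1, 2, 5]$ only by a constant, and the printed result is simply the z-score of $[1, 2, 5]$.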

Transition Data#

Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "reward",
        "reward_to_go",
        "log_prob"
    )
)

def convert_to_transition_with_fields_as_lists(transitions: list[Transition]) -> Transition:
    
    return Transition(*map(list, zip(*transitions)))
    
def convert_to_transitions(transition_with_fields_as_list: Transition) -> Transition:
    
    return list(map(
        lambda fields: Transition(*fields), 
        zip(*transition_with_fields_as_list)
    ))
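
To see what these two helpers do, here is a minimal round-trip check (the transition values are made up for illustration): convert_to_transition_with_fields_as_lists turns a list of per-step Transitions into a single Transition whose fields are lists, and convert_to_transitions inverts that.

t1 = Transition(state=[0.0], action=[0.1], reward=1.0, reward_to_go=1.5, log_prob=-0.7)
t2 = Transition(state=[0.2], action=[0.3], reward=2.0, reward_to_go=2.0, log_prob=-0.9)

batched = convert_to_transition_with_fields_as_lists([t1, t2])
# batched.state == [[0.0], [0.2]], batched.reward == [1.0, 2.0]

# Converting back recovers the original list of transitions
assert convert_to_transitions(batched) == [t1, t2]
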
def play_one_episode(
        env: Env,
        actor: Actor,
        max_n_timesteps_per_episode: int,
        gamma: float = 0.99
    ) -> Transition:
    
    states = []
    actions = []
    rewards = []
    log_probs = []

    state, _ = env.reset()
    is_done = False
    
    for t in range(max_n_timesteps_per_episode):
        
        # Select an action
        state = torch.tensor(state, dtype=torch.float)
        action = actor.select_action(state)
        action = action.detach().numpy()
        
        # Compute the log-probability of the action taken
        log_prob = actor.log_prob()
        log_prob = log_prob.detach().item()
        
        # Interact with the env
        next_state, reward, is_terminated, is_truncated, _ = env.step(action)
        
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        log_probs.append(log_prob)
        
        # The episode ends
        is_done = is_terminated or is_truncated
        if is_done:
            break
        
        # Step to the next state
        state = next_state
    
    # Compute the rewards-to-go 
    # based on the received rewards of entire episode
    rewards_to_go = compute_rewards_to_go(rewards, gamma)
    
    return Transition(
        state=states,
        action=actions,
        reward=rewards,
        reward_to_go=rewards_to_go,
        log_prob=log_probs,
    )
class ReplayBuffer(deque, Dataset):
    
    def __init__(
            self,
            env: Env,
            capacity: int,
            max_n_timesteps_per_episode: int,
            gamma: float
        ) -> None:
        
        super().__init__(maxlen=capacity)

        self._env = env
        self._capacity = capacity
        self._max_n_timesteps_per_episode = max_n_timesteps_per_episode
        self._gamma = gamma
    
    @property
    def capacity(self) -> int:
        return self._capacity
    
    @property
    def max_n_timesteps_per_episode(self) -> int:
        return self._max_n_timesteps_per_episode
        
    def collect(self, actor: Actor) -> None:
        
        while len(self) < self._capacity:
            
            transition_with_fields_as_list = play_one_episode(
                env=self._env,
                actor=actor,
                max_n_timesteps_per_episode=self._max_n_timesteps_per_episode,
                gamma=self._gamma
            )
            
            # Convert to list of transitions
            transitions = convert_to_transitions(transition_with_fields_as_list)
            
            # Add to the buffer
            self.extend(transitions)
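
Note that collect adds whole episodes at a time, so the final extend may overshoot the stopping condition; since the buffer is a deque with maxlen=capacity, the oldest transitions are then silently dropped and the buffer never holds more than capacity transitions.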

Updating Actor and Critic#

actor_opt = Adam(actor.parameters())
critic_opt = Adam(critic.parameters())
def update_actor_critic(
        actor: Actor,
        critic: Critic,
        actor_opt: Optimizer,
        critic_opt: Optimizer,
        replay_buffer_loader: DataLoader,
        n_epochs: int,
        epsilon: float = 0.2
    ):
    
    for _ in range(n_epochs):
        transition: Transition
        for transition in replay_buffer_loader:
            
            batch_states = transition.state
            batch_actions = transition.action
            
            # Log-probabilities under the current policy
            log_probs = actor.log_prob(batch_actions, batch_states)
            # Log-probabilities recorded when the actions were taken (old policy)
            batch_log_probs = transition.log_prob
            
            # Importance sampling ratios, computed in log space
            ratios = torch.exp(log_probs - batch_log_probs)
            
            batch_rewards_to_go = transition.reward_to_go.type(torch.float)
            state_values = critic.evaluate(batch_states)
            advantages = batch_rewards_to_go - state_values.detach()
        
            surr1 = ratios * advantages
            surr2 = torch.clip(
                ratios,
                1 - epsilon,
                1 + epsilon
            ) * advantages
            
            actor_loss = -torch.min(surr1, surr2).mean()
            
            # Update actor
            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()
            
            
            critic_loss = F.mse_loss(batch_rewards_to_go, state_values)

            # Update critic
            critic_opt.zero_grad()
            critic_loss.backward()
            critic_opt.step()
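
In the loop above, the probability ratio $r_t(\theta)$ from the clipped objective is obtained as $\exp\big(\log\pi_\theta(a_t \mid s_t) - \log\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)\big)$, i.e. the difference of log-probabilities is exponentiated rather than dividing raw probabilities, which is numerically more stable. Also note that, unlike compute_advantages above, the advantages used here are the raw $\hat{R}_t - V_\phi(s_t)$ without z-score normalization.
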
@dataclass
class PPOConfig:
    
    # Total number of epochs of the PPO algorithm
    n_epochs: int
    
    replay_buffer_capacity: int
    max_n_timesteps_per_episode: int
    
    # Discount factor
    gamma: float
    
    # Batch size of the data loader
    batch_size: int
    
    # Number of epochs to update the actor and critic networks
    n_epochs_for_actor_critic: int
    
    # Learning rate of the Adam optimizers
    lr: float

PPO Class#

class PPO:
    
    def __init__(
            self, 
            env: Env,
            config: PPOConfig
        ) -> None:
        
        self._env = env
        self._config = config
        state_dim = get_state_dim(env)
        action_spec = ActionSpec.from_env(env)
        
        # Actor and critic networks
        self._actor = Actor(state_dim, action_spec)
        self._critic = Critic(state_dim)
        
        # Optimizers
        self._actor_opt = Adam(self._actor.parameters(), lr=self._config.lr)
        self._critic_opt = Adam(self._critic.parameters(), lr=self._config.lr)
        
        # Replay buffer
        self._replay_buffer = ReplayBuffer(
            env=env,
            capacity=self._config.replay_buffer_capacity,
            max_n_timesteps_per_episode=self._config.max_n_timesteps_per_episode,
            gamma=self._config.gamma
        )
    
    @property
    def actor(self) -> Actor:
        return self._actor
    
    @property
    def critic(self) -> Critic:
        return self._critic
    
    def learn(self):
        
        for epoch in range(self._config.n_epochs):
            
            logging.info(f"PPO epoch: {epoch + 1}")
            
            # Collect transitions
            self._replay_buffer.collect(self._actor)
            
            # Mean reward per transition in the buffer,
            # logged as a rough progress signal
            rewards = [transition.reward for transition in self._replay_buffer]
            avg_episode_reward = np.mean(rewards)
            logging.info(f"average episode rewards: {avg_episode_reward}")
            
            # Create a data loader
            replay_buffer_loader = DataLoader(
                self._replay_buffer,
                batch_size=self._config.batch_size
            )
            
            # Train the actor and critic   
            update_actor_critic(
                actor=self._actor,
                critic=self._critic,
                actor_opt=self._actor_opt,
                critic_opt=self._critic_opt,
                replay_buffer_loader=replay_buffer_loader,
                n_epochs=self._config.n_epochs_for_actor_critic
            )
            
            # Clear replay buffer
            self._replay_buffer.clear()

Training#

ppo = PPO(
    env,
    config=PPOConfig(
        n_epochs=100,
        replay_buffer_capacity=4800,
        max_n_timesteps_per_episode=1600,
        gamma=0.99,
        batch_size=1024,
        n_epochs_for_actor_critic=10,
        lr=0.01
    )
)

ppo.learn()
2023-09-21 14:43:53,491 | INFO | PPO epoch: 1
2023-09-21 14:43:54,746 | INFO | average episode rewards: -6.110923935401077
2023-09-21 14:43:55,076 | INFO | PPO epoch: 2
2023-09-21 14:43:56,339 | INFO | average episode rewards: -6.071260639772841
2023-09-21 14:43:56,656 | INFO | PPO epoch: 3
2023-09-21 14:43:57,937 | INFO | average episode rewards: -6.286279557600787
2023-09-21 14:43:58,241 | INFO | PPO epoch: 4
2023-09-21 14:43:59,476 | INFO | average episode rewards: -5.504663757900611
2023-09-21 14:43:59,785 | INFO | PPO epoch: 5
2023-09-21 14:44:01,035 | INFO | average episode rewards: -6.153235834902657
2023-09-21 14:44:01,327 | INFO | PPO epoch: 6
2023-09-21 14:44:02,545 | INFO | average episode rewards: -6.730246981997852
2023-09-21 14:44:02,837 | INFO | PPO epoch: 7
2023-09-21 14:44:04,076 | INFO | average episode rewards: -5.464440702286453
2023-09-21 14:44:04,458 | INFO | PPO epoch: 8
2023-09-21 14:44:05,671 | INFO | average episode rewards: -5.593644874240346
2023-09-21 14:44:05,962 | INFO | PPO epoch: 9
2023-09-21 14:44:07,171 | INFO | average episode rewards: -6.399172619649428
2023-09-21 14:44:07,462 | INFO | PPO epoch: 10
2023-09-21 14:44:08,685 | INFO | average episode rewards: -6.022530136556304
2023-09-21 14:44:08,983 | INFO | PPO epoch: 11
2023-09-21 14:44:10,215 | INFO | average episode rewards: -6.050024077415373
2023-09-21 14:44:10,573 | INFO | PPO epoch: 12
2023-09-21 14:44:11,792 | INFO | average episode rewards: -5.595678804214014
2023-09-21 14:44:12,085 | INFO | PPO epoch: 13
2023-09-21 14:44:13,291 | INFO | average episode rewards: -5.7897880148753895
2023-09-21 14:44:13,583 | INFO | PPO epoch: 14
2023-09-21 14:44:14,799 | INFO | average episode rewards: -5.937452429329713
2023-09-21 14:44:15,093 | INFO | PPO epoch: 15
2023-09-21 14:44:16,314 | INFO | average episode rewards: -5.859606695803003
2023-09-21 14:44:16,672 | INFO | PPO epoch: 16
2023-09-21 14:44:17,887 | INFO | average episode rewards: -5.676640589709954
2023-09-21 14:44:18,195 | INFO | PPO epoch: 17
2023-09-21 14:44:19,419 | INFO | average episode rewards: -5.387011686028222
2023-09-21 14:44:19,744 | INFO | PPO epoch: 18
2023-09-21 14:44:20,956 | INFO | average episode rewards: -5.776525081077417
2023-09-21 14:44:21,245 | INFO | PPO epoch: 19
2023-09-21 14:44:22,450 | INFO | average episode rewards: -5.472712473588778
2023-09-21 14:44:22,805 | INFO | PPO epoch: 20
2023-09-21 14:44:24,005 | INFO | average episode rewards: -5.2392058947750675
2023-09-21 14:44:24,295 | INFO | PPO epoch: 21
2023-09-21 14:44:25,506 | INFO | average episode rewards: -6.096322011961469
2023-09-21 14:44:25,794 | INFO | PPO epoch: 22
2023-09-21 14:44:27,000 | INFO | average episode rewards: -5.723285307642029
2023-09-21 14:44:27,288 | INFO | PPO epoch: 23
2023-09-21 14:44:28,511 | INFO | average episode rewards: -5.781585879501909
2023-09-21 14:44:28,864 | INFO | PPO epoch: 24
2023-09-21 14:44:30,132 | INFO | average episode rewards: -5.404026350685467
2023-09-21 14:44:30,459 | INFO | PPO epoch: 25
2023-09-21 14:44:31,703 | INFO | average episode rewards: -5.59045952709884
2023-09-21 14:44:32,024 | INFO | PPO epoch: 26
2023-09-21 14:44:33,242 | INFO | average episode rewards: -5.630707363056575
2023-09-21 14:44:33,556 | INFO | PPO epoch: 27
2023-09-21 14:44:34,826 | INFO | average episode rewards: -5.025507696244329
2023-09-21 14:44:35,217 | INFO | PPO epoch: 28
2023-09-21 14:44:36,527 | INFO | average episode rewards: -5.972768697684134
2023-09-21 14:44:36,915 | INFO | PPO epoch: 29
2023-09-21 14:44:38,185 | INFO | average episode rewards: -5.367618255026877
2023-09-21 14:44:38,477 | INFO | PPO epoch: 30
2023-09-21 14:44:39,777 | INFO | average episode rewards: -5.140174826735197
2023-09-21 14:44:40,114 | INFO | PPO epoch: 31
2023-09-21 14:44:41,363 | INFO | average episode rewards: -4.981673487006734
2023-09-21 14:44:41,729 | INFO | PPO epoch: 32
2023-09-21 14:44:42,991 | INFO | average episode rewards: -5.127025520256155
2023-09-21 14:44:43,300 | INFO | PPO epoch: 33
2023-09-21 14:44:44,525 | INFO | average episode rewards: -5.3419723138993955
2023-09-21 14:44:44,813 | INFO | PPO epoch: 34
2023-09-21 14:44:46,031 | INFO | average episode rewards: -5.30891222853752
2023-09-21 14:44:46,318 | INFO | PPO epoch: 35
2023-09-21 14:44:47,529 | INFO | average episode rewards: -5.684426722799187
2023-09-21 14:44:47,885 | INFO | PPO epoch: 36
2023-09-21 14:44:49,138 | INFO | average episode rewards: -4.8530052231381
2023-09-21 14:44:49,477 | INFO | PPO epoch: 37
2023-09-21 14:44:50,708 | INFO | average episode rewards: -5.166183728959908
2023-09-21 14:44:50,999 | INFO | PPO epoch: 38
2023-09-21 14:44:52,218 | INFO | average episode rewards: -4.636682895422886
2023-09-21 14:44:52,507 | INFO | PPO epoch: 39
2023-09-21 14:44:53,753 | INFO | average episode rewards: -4.71075184322738
2023-09-21 14:44:54,224 | INFO | PPO epoch: 40
2023-09-21 14:44:55,455 | INFO | average episode rewards: -5.1569811687949425
2023-09-21 14:44:55,763 | INFO | PPO epoch: 41
2023-09-21 14:44:57,035 | INFO | average episode rewards: -4.79638571747526
2023-09-21 14:44:57,336 | INFO | PPO epoch: 42
2023-09-21 14:44:58,561 | INFO | average episode rewards: -4.835689811830781
2023-09-21 14:44:58,849 | INFO | PPO epoch: 43
2023-09-21 14:45:00,107 | INFO | average episode rewards: -4.282140943069902
2023-09-21 14:45:00,494 | INFO | PPO epoch: 44
2023-09-21 14:45:01,763 | INFO | average episode rewards: -4.510575011972921
2023-09-21 14:45:02,050 | INFO | PPO epoch: 45
2023-09-21 14:45:03,255 | INFO | average episode rewards: -4.3932612586966036
2023-09-21 14:45:03,540 | INFO | PPO epoch: 46
2023-09-21 14:45:04,737 | INFO | average episode rewards: -4.474887797060449
2023-09-21 14:45:05,021 | INFO | PPO epoch: 47
2023-09-21 14:45:06,216 | INFO | average episode rewards: -4.7754210740039325
2023-09-21 14:45:06,608 | INFO | PPO epoch: 48
2023-09-21 14:45:07,833 | INFO | average episode rewards: -4.3594883025611475
2023-09-21 14:45:08,129 | INFO | PPO epoch: 49
2023-09-21 14:45:09,351 | INFO | average episode rewards: -4.323283747093484
2023-09-21 14:45:09,641 | INFO | PPO epoch: 50
2023-09-21 14:45:10,858 | INFO | average episode rewards: -4.232465984200838
2023-09-21 14:45:11,172 | INFO | PPO epoch: 51
2023-09-21 14:45:12,394 | INFO | average episode rewards: -4.255824298900807
2023-09-21 14:45:12,744 | INFO | PPO epoch: 52
2023-09-21 14:45:14,001 | INFO | average episode rewards: -4.308886162348113
2023-09-21 14:45:14,401 | INFO | PPO epoch: 53
2023-09-21 14:45:15,710 | INFO | average episode rewards: -4.323119418666096
2023-09-21 14:45:16,027 | INFO | PPO epoch: 54
2023-09-21 14:45:17,249 | INFO | average episode rewards: -4.257481836218371
2023-09-21 14:45:17,541 | INFO | PPO epoch: 55
2023-09-21 14:45:18,771 | INFO | average episode rewards: -4.084458802543108
2023-09-21 14:45:19,127 | INFO | PPO epoch: 56
2023-09-21 14:45:20,362 | INFO | average episode rewards: -4.169530096280342
2023-09-21 14:45:20,654 | INFO | PPO epoch: 57
2023-09-21 14:45:21,872 | INFO | average episode rewards: -4.137685485908213
2023-09-21 14:45:22,162 | INFO | PPO epoch: 58
2023-09-21 14:45:23,372 | INFO | average episode rewards: -4.249553402858064
2023-09-21 14:45:23,664 | INFO | PPO epoch: 59
2023-09-21 14:45:24,873 | INFO | average episode rewards: -3.964422105153224
2023-09-21 14:45:25,227 | INFO | PPO epoch: 60
2023-09-21 14:45:26,440 | INFO | average episode rewards: -4.064252007168193
2023-09-21 14:45:26,730 | INFO | PPO epoch: 61
2023-09-21 14:45:27,941 | INFO | average episode rewards: -3.8129484918542738
2023-09-21 14:45:28,243 | INFO | PPO epoch: 62
2023-09-21 14:45:29,453 | INFO | average episode rewards: -4.10846290588343
2023-09-21 14:45:29,759 | INFO | PPO epoch: 63
2023-09-21 14:45:30,966 | INFO | average episode rewards: -4.100469735873232
2023-09-21 14:45:31,319 | INFO | PPO epoch: 64
2023-09-21 14:45:32,535 | INFO | average episode rewards: -3.988986004465981
2023-09-21 14:45:32,823 | INFO | PPO epoch: 65
2023-09-21 14:45:34,042 | INFO | average episode rewards: -4.172985700446049
2023-09-21 14:45:34,329 | INFO | PPO epoch: 66
2023-09-21 14:45:35,544 | INFO | average episode rewards: -3.9492351237683425
2023-09-21 14:45:35,834 | INFO | PPO epoch: 67
2023-09-21 14:45:37,055 | INFO | average episode rewards: -4.037028037636843
2023-09-21 14:45:37,408 | INFO | PPO epoch: 68
2023-09-21 14:45:38,632 | INFO | average episode rewards: -3.8717523354590213
2023-09-21 14:45:38,919 | INFO | PPO epoch: 69
2023-09-21 14:45:40,139 | INFO | average episode rewards: -3.6380695252078232
2023-09-21 14:45:40,427 | INFO | PPO epoch: 70
2023-09-21 14:45:41,644 | INFO | average episode rewards: -3.7535559594569334
2023-09-21 14:45:41,932 | INFO | PPO epoch: 71
2023-09-21 14:45:43,145 | INFO | average episode rewards: -3.717212148282896
2023-09-21 14:45:43,497 | INFO | PPO epoch: 72
2023-09-21 14:45:44,707 | INFO | average episode rewards: -3.8200672704970398
2023-09-21 14:45:44,997 | INFO | PPO epoch: 73
2023-09-21 14:45:46,217 | INFO | average episode rewards: -3.8056665582783458
2023-09-21 14:45:46,505 | INFO | PPO epoch: 74
2023-09-21 14:45:47,727 | INFO | average episode rewards: -3.671062882379162
2023-09-21 14:45:48,017 | INFO | PPO epoch: 75
2023-09-21 14:45:49,240 | INFO | average episode rewards: -3.6105846328232505
2023-09-21 14:45:49,601 | INFO | PPO epoch: 76
2023-09-21 14:45:50,822 | INFO | average episode rewards: -3.5088284661623703
2023-09-21 14:45:51,111 | INFO | PPO epoch: 77
2023-09-21 14:45:52,331 | INFO | average episode rewards: -3.463422440033662
2023-09-21 14:45:52,619 | INFO | PPO epoch: 78
2023-09-21 14:45:53,830 | INFO | average episode rewards: -3.390544654137758
2023-09-21 14:45:54,118 | INFO | PPO epoch: 79
2023-09-21 14:45:55,336 | INFO | average episode rewards: -3.5680100887619566
2023-09-21 14:45:55,688 | INFO | PPO epoch: 80
2023-09-21 14:45:56,900 | INFO | average episode rewards: -3.352304634466086
2023-09-21 14:45:57,190 | INFO | PPO epoch: 81
2023-09-21 14:45:58,413 | INFO | average episode rewards: -3.308571335618333
2023-09-21 14:45:58,704 | INFO | PPO epoch: 82
2023-09-21 14:45:59,933 | INFO | average episode rewards: -3.363160211521526
2023-09-21 14:46:00,222 | INFO | PPO epoch: 83
2023-09-21 14:46:01,427 | INFO | average episode rewards: -3.336297793707316
2023-09-21 14:46:01,779 | INFO | PPO epoch: 84
2023-09-21 14:46:03,000 | INFO | average episode rewards: -3.1232256799079816
2023-09-21 14:46:03,288 | INFO | PPO epoch: 85
2023-09-21 14:46:04,505 | INFO | average episode rewards: -2.893729036215459
2023-09-21 14:46:04,795 | INFO | PPO epoch: 86
2023-09-21 14:46:06,015 | INFO | average episode rewards: -2.7299993584978868
2023-09-21 14:46:06,303 | INFO | PPO epoch: 87
2023-09-21 14:46:07,514 | INFO | average episode rewards: -2.615799076306543
2023-09-21 14:46:07,869 | INFO | PPO epoch: 88
2023-09-21 14:46:09,112 | INFO | average episode rewards: -2.207611983904887
2023-09-21 14:46:09,409 | INFO | PPO epoch: 89
2023-09-21 14:46:10,650 | INFO | average episode rewards: -1.9580470846103188
2023-09-21 14:46:10,939 | INFO | PPO epoch: 90
2023-09-21 14:46:12,155 | INFO | average episode rewards: -1.6758821685748795
2023-09-21 14:46:12,446 | INFO | PPO epoch: 91
2023-09-21 14:46:13,662 | INFO | average episode rewards: -1.1631944322504268
2023-09-21 14:46:14,014 | INFO | PPO epoch: 92
2023-09-21 14:46:15,230 | INFO | average episode rewards: -1.1940382901557742
2023-09-21 14:46:15,519 | INFO | PPO epoch: 93
2023-09-21 14:46:16,730 | INFO | average episode rewards: -1.186452953358785
2023-09-21 14:46:17,020 | INFO | PPO epoch: 94
2023-09-21 14:46:18,232 | INFO | average episode rewards: -1.1377638725134356
2023-09-21 14:46:18,523 | INFO | PPO epoch: 95
2023-09-21 14:46:19,759 | INFO | average episode rewards: -1.2804482369598744
2023-09-21 14:46:20,116 | INFO | PPO epoch: 96
2023-09-21 14:46:21,333 | INFO | average episode rewards: -1.2031686884665314
2023-09-21 14:46:21,622 | INFO | PPO epoch: 97
2023-09-21 14:46:22,836 | INFO | average episode rewards: -1.217157130940351
2023-09-21 14:46:23,124 | INFO | PPO epoch: 98
2023-09-21 14:46:24,343 | INFO | average episode rewards: -1.0281718137516214
2023-09-21 14:46:24,632 | INFO | PPO epoch: 99
2023-09-21 14:46:25,864 | INFO | average episode rewards: -1.0543439617908212
2023-09-21 14:46:26,218 | INFO | PPO epoch: 100
2023-09-21 14:46:27,457 | INFO | average episode rewards: -0.9634774535826106
torch.save(ppo.actor, "../models/actor.pth")
torch.save(ppo.critic, "../models/critic.pth")

Loading Models#

actor = torch.load("../models/actor.pth")
critic = torch.load("../models/critic.pth")
actor.eval()
critic.eval()

print(actor)
print(critic)
Actor(
  (mlp): MLP(
    (layer1): Linear(in_features=3, out_features=64, bias=True)
    (layer2): Linear(in_features=64, out_features=64, bias=True)
    (layer3): Linear(in_features=64, out_features=1, bias=True)
  )
  (dist_param1): Sequential(
    (0): Linear(in_features=1, out_features=16, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
  (dist_param2): Sequential(
    (0): Linear(in_features=1, out_features=16, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
)
Critic(
  (mlp): MLP(
    (layer1): Linear(in_features=3, out_features=64, bias=True)
    (layer2): Linear(in_features=64, out_features=64, bias=True)
    (layer3): Linear(in_features=64, out_features=1, bias=True)
  )
)
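
torch.save(module, path) above pickles the full module objects (including references to their class definitions), which is convenient here but ties the saved files to this code layout. A common alternative, shown as a minimal sketch below (not what this notebook does; the file names are made up), is to persist only the state dicts:

# Save only the learned parameters (hypothetical file names)
torch.save(ppo.actor.state_dict(), "../models/actor_state_dict.pth")
torch.save(ppo.critic.state_dict(), "../models/critic_state_dict.pth")

# Re-create the networks, then restore their parameters
actor = Actor(state_dim, action_spec)
actor.load_state_dict(torch.load("../models/actor_state_dict.pth"))
actor.eval()

critic = Critic(state_dim)
critic.load_state_dict(torch.load("../models/critic_state_dict.pth"))
critic.eval()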

Performance of the Trained Agent#

env = gym.make(
    "Pendulum-v1", 
    max_episode_steps=1000,
    render_mode="rgb_array"
)

env = RecordVideo(
    env, 
    video_folder="../_static/videos",
    name_prefix=f"ppo-{env.spec.id}"
)

state, _ = env.reset()

for t in count():
    
    # Select an action
    state = torch.tensor(state, dtype=torch.float)
    action = actor.select_action(state)
    action = action.detach().numpy()
    
    # Interact with the environment
    next_state, reward, is_terminated, is_truncated, info = env.step(action)
    
    is_done = is_terminated or is_truncated
    
    if is_done:
        break
    
    # Step to the next state
    state = next_state

env.close()
/opt/anaconda3/envs/linguaml/lib/python3.11/site-packages/gymnasium/wrappers/record_video.py:94: UserWarning: WARN: Overwriting existing videos at /Users/isaac/Developer/py-projects/linguAML/book/_static/videos folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)
  logger.warn(
Moviepy - Building video /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4.
Moviepy - Writing video /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4
                                                                
Moviepy - Done !
Moviepy - video ready /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4
Video("../_static/videos/ppo-Pendulum-v1-episode-0.mp4")