# PPO

In [84]:
from typing import Self, Optional
from collections import namedtuple, deque
from dataclasses import dataclass
from itertools import count
import logging
import numpy as np
from scipy import stats

import gymnasium as gym
from gymnasium import Env
from gymnasium.wrappers.record_video import RecordVideo

import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
from torch.distributions import Distribution, Normal
from torch.optim import Optimizer, Adam

from IPython.display import Video

In [85]:
# Seed both sources of randomness used by this notebook so that
# "Restart Kernel -> Run All" reproduces the same rollouts.
SEED = 42
torch.manual_seed(SEED)  # network initialisation and action sampling
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)

env = gym.make("Pendulum-v1")
# Seed the environment's own generator once; later `env.reset()` calls
# made without a seed keep drawing from this seeded generator.
env.reset(seed=SEED)

## PPO Algorithm

The following snapshot shows the pseudocode of the PPO-Clip algorithm:

<img src="../figures/ppo-clip.png" width="500" alt="Pseudocode of the PPO-Clip algorithm">

## Environment Specifications

In [86]:
def get_state_dim(env: Env) -> Optional[int]:
    """Return the length of the first axis of the observation space.

    Args:
        env: Environment whose ``observation_space.shape`` is inspected.

    Returns:
        ``observation_space.shape[0]``, or ``None`` when the space has no
        such entry (missing attribute, ``shape`` of ``None``, or an empty
        shape tuple).
    """
    try:
        return env.observation_space.shape[0]
    # Only the "no usable shape" cases are treated as "unknown"; anything
    # else (e.g. KeyboardInterrupt) is allowed to propagate.
    except (AttributeError, IndexError, TypeError):
        return None

def get_action_dim(env: Env) -> Optional[int]:
    """Return the length of the first axis of the action space.

    Args:
        env: Environment whose ``action_space.shape`` is inspected.

    Returns:
        ``action_space.shape[0]``, or ``None`` when the space has no such
        entry (missing attribute, ``shape`` of ``None``, or an empty shape
        tuple).
    """
    try:
        return env.action_space.shape[0]
    # Only the "no usable shape" cases are treated as "unknown"; anything
    # else (e.g. KeyboardInterrupt) is allowed to propagate.
    except (AttributeError, IndexError, TypeError):
        return None
    
# Sanity check: show the dimensions detected for the environment created above
print(f"state space dim: {get_state_dim(env)}")
print(f"action space dim: {get_action_dim(env)}")

state space dim: 3
action space dim: 1


In [87]:
@dataclass
class ActionSpec:
    """Size and bounds of an environment's action space.

    Built from an action space that exposes ``shape``, ``low`` and ``high``.
    """

    # Number of action components (first entry of the space's shape)
    dim: int
    # Lower bounds of the action space, as provided by the space
    min: np.ndarray
    # Upper bounds of the action space, as provided by the space
    max: np.ndarray

    @classmethod
    def from_env(cls, env: Env) -> Self:
        """Read the action dimension and bounds from ``env.action_space``.

        Args:
            env: Environment whose action space has ``shape``, ``low`` and
                ``high`` attributes.

        Returns:
            A new ``ActionSpec`` holding the space's own bound arrays.
        """
        action_space = env.action_space

        # Regular attribute access instead of poking at ``__dict__``: it is
        # the public way to read the bounds and also works when a space
        # exposes them as properties.
        return cls(
            dim=action_space.shape[0],
            min=action_space.low,
            max=action_space.high
        )

# Specs of the demo environment; `state_dim` and `action_spec` are reused
# by the example cells further down.
state_dim = get_state_dim(env)
action_spec = ActionSpec.from_env(env)

print(f"state dim: {state_dim}")
print(f"action spec: {action_spec}")

state dim: 3
action spec: ActionSpec(dim=1, min=array([-2.], dtype=float32), max=array([2.], dtype=float32))


## Actor and Critic Models

In [88]:
class MLP(nn.Module):
    """Fully connected network: two 64-unit ReLU hidden layers, linear output."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        """Create the three linear layers.

        Args:
            in_dim: Size of the last axis of the input.
            out_dim: Size of the last axis of the output.
        """
        super().__init__()

        hidden_dim = 64
        self.layer1 = nn.Linear(in_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, state: Tensor) -> Tensor:
        """Map ``state`` to an unactivated output of size ``out_dim``."""
        hidden = F.relu(self.layer1(state))
        hidden = F.relu(self.layer2(hidden))
        return self.layer3(hidden)


In [89]:
class Actor(nn.Module):
    """Gaussian policy with a state-dependent mean and a fixed standard deviation.

    The mean is the MLP output squashed by a sigmoid and rescaled to the
    action bounds; the standard deviation is the constant ``0.5``.

    The distribution produced by the latest forward pass and the latest
    sampled action are cached on the instance, so ``select_action()`` and
    ``log_prob()`` can be called without arguments right after a forward pass.
    """

    # Fixed standard deviation of the action distribution
    _STD = 0.5

    def __init__(
            self,
            state_dim: int,
            action_spec: ActionSpec
        ) -> None:
        """Build the policy network.

        Args:
            state_dim: Size of the last axis of the state tensors.
            action_spec: Provides ``dim``, ``min`` and ``max`` of the actions.
        """
        super().__init__()

        # Non-persistent buffers: they follow the module through
        # ``.to(device)`` / ``.cuda()`` but stay out of ``state_dict()``,
        # so the set of saved keys is unchanged.
        self.register_buffer(
            "_action_min",
            torch.tensor(action_spec.min, dtype=torch.float),
            persistent=False
        )
        self.register_buffer(
            "_action_max",
            torch.tensor(action_spec.max, dtype=torch.float),
            persistent=False
        )

        # Cache filled by forward() and select_action()
        self._distribution: Optional[Distribution] = None
        self._action: Optional[Tensor] = None

        self.mlp = MLP(
            in_dim=state_dim,
            out_dim=action_spec.dim
        )

        # These two heads are not used by forward(). They are kept so that
        # the module structure, the parameter names and the order in which
        # parameters are initialised stay exactly as before.
        self.dist_param1 = nn.Sequential(
            nn.Linear(action_spec.dim, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, action_spec.dim),
            nn.Sigmoid()
        )

        self.dist_param2 = nn.Sequential(
            nn.Linear(action_spec.dim, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, action_spec.dim),
            nn.Sigmoid()
        )

    @property
    def distribution(self) -> Optional[Distribution]:
        """Distribution from the latest forward pass, or ``None`` before the first one."""
        return self._distribution

    def forward(self, state: Tensor) -> Distribution:
        """Compute and cache the action distribution for ``state``.

        Args:
            state: State tensor whose last axis has size ``state_dim``.

        Returns:
            A ``Normal`` whose mean lies between the action bounds and whose
            standard deviation is 0.5 everywhere.
        """
        out: Tensor = self.mlp(state)

        # Squash to (0, 1), then rescale to [action_min, action_max]
        action_range = self._action_max - self._action_min
        mean = torch.sigmoid(out) * action_range + self._action_min
        std = torch.full_like(mean, self._STD)

        distribution = Normal(mean, std)

        # Store the distribution for select_action() / log_prob()
        self._distribution = distribution

        return distribution

    def select_action(
            self,
            state: Optional[Tensor] = None
        ) -> Tensor:
        """Sample an action, clip it to the action bounds and cache it.

        Args:
            state: If given, the distribution is recomputed for this state;
                otherwise the cached distribution is sampled.

        Returns:
            The clipped action.

        Raises:
            RuntimeError: If no state is given and no forward pass has been
                run yet.
        """
        if state is not None:
            self.forward(state)

        if self._distribution is None:
            raise RuntimeError(
                "no cached distribution: pass a state or run a forward pass first"
            )

        # Sample an action from the distribution
        action = self._distribution.sample()

        # Clip the action value by upper and lower bounds
        action = action.clip(self._action_min, self._action_max)

        # Store the action taken (the clipped one), so that log_prob()
        # without arguments evaluates exactly what was returned here.
        self._action = action

        return action

    def log_prob(
            self,
            action: Optional[Tensor] = None,
            state: Optional[Tensor] = None
        ) -> Tensor:
        """Log-probability of an action, summed over the action components.

        Args:
            action: Action(s) to evaluate. If ``None``, the cached action is
                evaluated under the cached distribution.
            state: If given together with ``action``, the distribution is
                recomputed for this state first. Must be ``None`` when
                ``action`` is ``None``.

        Returns:
            The log-probabilities with the last axis summed out.

        Raises:
            RuntimeError: If the required cached distribution or action does
                not exist yet.
        """
        if action is None:
            assert state is None,\
                "the state must be set None since the action is None"
            if self._distribution is None or self._action is None:
                raise RuntimeError(
                    "no cached action: call select_action() first"
                )
            return self._distribution.log_prob(self._action).sum(dim=-1)

        if state is not None:
            self.forward(state)

        if self._distribution is None:
            raise RuntimeError(
                "no cached distribution: pass a state or run a forward pass first"
            )

        return self._distribution.log_prob(action).sum(dim=-1)

# Smoke test: sample one action for a freshly reset state and evaluate
# its log-probability from the values cached on the actor.
actor = Actor(state_dim, action_spec)
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float)
action = actor.select_action(state)
log_prob = actor.log_prob()

print(f"action: {action}")
print(f"log-probability: {log_prob}")

action: tensor([0.4161])
log-probability: -0.7357034683227539


In [90]:
# Show the distribution cached by the most recent forward pass
actor.distribution

Normal(loc: tensor([-0.0888], grad_fn=<AddBackward0>), scale: tensor([0.5000]))

In [91]:
class Critic(nn.Module):
    """State-value estimator: an MLP with a single output per state."""

    def __init__(self, state_dim: int) -> None:
        """Build the value network for states of size ``state_dim``."""
        super().__init__()

        self.mlp = MLP(in_dim=state_dim, out_dim=1)

    def evaluate(self, state: Tensor) -> Tensor:
        """Return the value estimate(s) with the trailing unit axis removed."""
        return self.mlp(state).squeeze(dim=-1)

    def forward(self, state: Tensor) -> Tensor:
        """Alias of ``evaluate`` so the module can be called directly."""
        return self.evaluate(state)

# Smoke test: value estimate of a freshly initialised critic for one reset state
critic = Critic(state_dim)
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float)
value = critic(state)

print(f"estimated state value: {value}")

estimated state value: 0.1734076738357544


In [92]:
def compute_rewards_to_go(rewards: list[float], gamma: float) -> list[float]:
    """Discounted sum of rewards from each timestep to the end of the episode.

    Works backwards through the episode with the recurrence
    ``G[t] = rewards[t] + gamma * G[t + 1]``, where the value after the
    last step is 0. The input list is not modified.

    Args:
        rewards: Rewards of one episode in chronological order.
        gamma: Discount factor.

    Returns:
        One reward-to-go per input reward, in chronological order.
    """
    discounted = []
    running = 0

    # Walk the episode from the last step to the first
    for reward in reversed(rewards):
        running = reward + gamma * running
        discounted.append(running)

    # Values were produced back to front; restore chronological order
    discounted.reverse()

    return discounted

# Worked example, back to front: 100, then 10 + 0.1 * 100, then 1 + 0.1 * 20
compute_rewards_to_go([1, 10, 100], gamma=0.1)

[3.0, 20.0, 100.0]

In [93]:
def compute_advantages(
        critic: Critic,
        rewards_to_go: list[float], 
        states: list[np.ndarray]
    ) -> np.ndarray:
    """Normalised advantages: rewards-to-go minus the critic's value estimates.

    Args:
        critic: Object whose ``evaluate(states)`` returns one value per state.
        rewards_to_go: One reward-to-go per state.
        states: The visited states, one array per timestep.

    Returns:
        The advantages standardised to zero mean and unit standard
        deviation. When every advantage is identical there is nothing to
        standardise, so an array of zeros is returned instead of NaNs.
    """
    # State values; this is pure evaluation, so no autograd graph is needed
    states = torch.tensor(np.array(states), dtype=torch.float)
    with torch.no_grad():
        values = critic.evaluate(states)
    values = values.detach().numpy()

    # Compute advantages
    rewards_to_go = np.array(rewards_to_go)
    advantages: np.ndarray = rewards_to_go - values

    # A constant batch has zero standard deviation; z-scoring it would
    # divide by zero and fill the result with NaN.
    if advantages.size > 0 and np.all(advantages == advantages.flat[0]):
        return np.zeros_like(advantages)

    # Normalize the advantages
    advantages = stats.zscore(advantages)

    return advantages

# Worked example: the same state is repeated three times, so all value
# estimates are equal and the z-scores depend only on the rewards-to-go.
advantages = compute_advantages(
    critic,
    [1, 2, 5],
    [state] * 3
)

print(f"advantages: {advantages}")

advantages: [-0.98058067 -0.39223227  1.37281295]


## Transition Data

In [94]:
# Record type for experience data. It is used in two layouts: one record
# per step (each field a single value), or one record per episode (each
# field a list with one entry per step).
Transition = namedtuple(
    "Transition",
    ["state", "action", "reward", "reward_to_go", "log_prob"]
)

def convert_to_transition_with_fields_as_lists(transitions: list[Transition]) -> Transition:
    """Transpose per-step transitions into one Transition of per-field lists.

    Args:
        transitions: One ``Transition`` per step.

    Returns:
        A single ``Transition`` whose fields are lists with one entry per
        step. For an empty input every field is an empty list.
    """
    # zip(*[]) yields nothing, which would leave Transition() without
    # arguments; build the empty per-field lists explicitly instead.
    if not transitions:
        return Transition(*([] for _ in Transition._fields))

    return Transition(*map(list, zip(*transitions)))
    
def convert_to_transitions(transition_with_fields_as_list: Transition) -> list[Transition]:
    """Split a Transition of per-field lists into one Transition per step.

    Args:
        transition_with_fields_as_list: A ``Transition`` whose fields are
            equally long sequences with one entry per step.

    Returns:
        A list with one ``Transition`` per step, in the original order.
    """
    return [
        Transition(*fields)
        for fields in zip(*transition_with_fields_as_list)
    ]


In [95]:
def play_one_episode(
        env: Env,
        actor: Actor,
        max_n_timesteps_per_episode: int,
        gamma: float = 0.99
    ) -> Transition:
    """Roll out one episode with ``actor`` and record every step.

    The episode stops when the environment reports termination or
    truncation, or after ``max_n_timesteps_per_episode`` steps.

    Args:
        env: Environment to interact with; it is reset at the start.
        actor: Policy used to pick actions. Its cached distribution and
            action are overwritten, and the cached distribution carries no
            autograd graph afterwards.
        max_n_timesteps_per_episode: Upper bound on the number of steps.
        gamma: Discount factor for the rewards-to-go.

    Returns:
        A ``Transition`` whose fields are lists with one entry per step:
        state tensors, action arrays, rewards, rewards-to-go and the
        log-probabilities of the actions taken.
    """
    states = []
    actions = []
    rewards = []
    log_probs = []

    state, _ = env.reset()

    for _ in range(max_n_timesteps_per_episode):

        state = torch.tensor(state, dtype=torch.float)

        # Rollouts only produce data (everything is detached below), so
        # skip building autograd graphs while acting.
        with torch.no_grad():
            # Select an action
            action = actor.select_action(state)
            # Log-probability of the action just taken (read from the actor's cache)
            log_prob = actor.log_prob()

        action = action.detach().numpy()
        log_prob = log_prob.detach().item()

        # Interact with the env
        next_state, reward, is_terminated, is_truncated, _ = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        log_probs.append(log_prob)

        # The episode ends
        if is_terminated or is_truncated:
            break

        # Step to the next state
        state = next_state

    # Compute the rewards-to-go
    # based on the received rewards of entire episode
    rewards_to_go = compute_rewards_to_go(rewards, gamma)

    return Transition(
        state=states,
        action=actions,
        reward=rewards,
        reward_to_go=rewards_to_go,
        log_prob=log_probs,
    )


In [96]:
class ReplayBuffer(deque, Dataset):
    """Bounded deque of per-step transitions that can be fed to a ``DataLoader``.

    The deque is created with ``maxlen=capacity``, so once it is full,
    appending more transitions discards the oldest ones.
    """

    def __init__(
            self,
            env: Env,
            capacity: int,
            max_n_timesteps_per_episode: int,
            gamma: float
        ) -> None:
        """Create an empty buffer.

        Args:
            env: Environment used to generate episodes.
            capacity: Maximum number of transitions kept.
            max_n_timesteps_per_episode: Step limit of each episode played.
            gamma: Discount factor for the rewards-to-go.

        Raises:
            ValueError: If ``max_n_timesteps_per_episode`` is smaller than 1.
        """
        # Episodes of zero steps add nothing to the buffer, so collect()
        # would never reach the capacity and would loop forever.
        if max_n_timesteps_per_episode < 1:
            raise ValueError(
                "max_n_timesteps_per_episode must be at least 1, "
                f"got {max_n_timesteps_per_episode}"
            )

        super().__init__(maxlen=capacity)

        self._env = env
        self._capacity = capacity
        self._max_n_timesteps_per_episode = max_n_timesteps_per_episode
        self._gamma = gamma

    @property
    def capacity(self) -> int:
        """Maximum number of transitions kept."""
        return self._capacity

    @property
    def max_n_timesteps_per_episode(self) -> int:
        """Step limit of each episode played by ``collect``."""
        return self._max_n_timesteps_per_episode

    def collect(self, actor: Actor) -> None:
        """Play whole episodes with ``actor`` until the buffer is full.

        Episodes are appended in full; if the last one overshoots the
        capacity, the deque drops the oldest transitions to make room.
        """
        while len(self) < self._capacity:

            transition_with_fields_as_list = play_one_episode(
                env=self._env,
                actor=actor,
                max_n_timesteps_per_episode=self._max_n_timesteps_per_episode,
                gamma=self._gamma
            )

            # Convert to list of transitions
            transitions = convert_to_transitions(transition_with_fields_as_list)

            # Add to the buffer
            self.extend(transitions)


## Updating Actor and Critic

In [97]:
# Optimizers for the stand-alone `actor` / `critic` created in the example
# cells above, using Adam's default settings. The PPO class below builds
# its own optimizers from the configured learning rate.
actor_opt = Adam(actor.parameters())
critic_opt = Adam(critic.parameters())

In [98]:
def update_actor_critic(
        actor: Actor,
        critic: Critic,
        actor_opt: Optimizer,
        critic_opt: Optimizer,
        replay_buffer_loader: DataLoader,
        n_epochs: int,
        epsilon: float = 0.2
    ) -> None:
    """Run PPO-Clip updates of the actor and MSE updates of the critic.

    For every mini-batch the actor takes one step on the clipped surrogate
    objective and the critic takes one step towards the rewards-to-go.

    Args:
        actor: Policy; ``log_prob(actions, states)`` gives the
            log-probabilities under the current parameters.
        critic: Value network; ``evaluate(states)`` gives the state values.
        actor_opt: Optimizer over the actor's parameters.
        critic_opt: Optimizer over the critic's parameters.
        replay_buffer_loader: Iterable of batched ``Transition`` records.
        n_epochs: Number of passes over ``replay_buffer_loader``.
        epsilon: Clipping range; ratios are clipped to
            ``[1 - epsilon, 1 + epsilon]``.
    """
    for _ in range(n_epochs):
        transition: Transition
        for transition in replay_buffer_loader:

            batch_states = transition.state
            batch_actions = transition.action

            # Bring the recorded quantities to float32, like the network
            # outputs, so the losses are not silently promoted to float64.
            old_log_probs = transition.log_prob.type(torch.float)
            batch_rewards_to_go = transition.reward_to_go.type(torch.float)

            # Probability ratio between the current policy and the policy
            # that collected the data
            new_log_probs = actor.log_prob(batch_actions, batch_states)
            ratios = torch.exp(new_log_probs - old_log_probs)

            # Advantages are treated as constants in the actor loss
            state_values = critic.evaluate(batch_states)
            advantages = batch_rewards_to_go - state_values.detach()

            surr1 = ratios * advantages
            surr2 = torch.clip(
                ratios,
                1 - epsilon,
                1 + epsilon
            ) * advantages

            actor_loss = -torch.min(surr1, surr2).mean()

            # Update actor
            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()

            # (input, target) order: predictions first, regression target second
            critic_loss = F.mse_loss(state_values, batch_rewards_to_go)

            # Update critic
            critic_opt.zero_grad()
            critic_loss.backward()
            critic_opt.step()


In [99]:
@dataclass
class PPOConfig:
    """Hyper-parameters consumed by the PPO class."""
    
    # Total number of epochs of the PPO algorithm
    # (each epoch collects fresh transitions and then updates the networks)
    n_epochs: int
    
    # Number of transitions the replay buffer holds; collection runs until it is full
    replay_buffer_capacity: int
    # Upper bound on the number of steps played in a single episode
    max_n_timesteps_per_episode: int
    
    # Discount factor
    gamma: float
    
    # Batch size of the data loader
    batch_size: int
    
    # Number of epochs to update actor and critic networks
    n_epochs_for_actor_critic: int
    
    # Learning rate of the Adam optimizers
    lr: float


## PPO Class

In [100]:
class PPO:
    """PPO-Clip agent: owns the actor, the critic, their optimizers and the buffer."""

    def __init__(
            self, 
            env: Env,
            config: PPOConfig
        ) -> None:
        """Build the networks, optimizers and replay buffer for ``env``.

        Args:
            env: Environment to train on.
            config: Hyper-parameters of the algorithm.
        """
        self._env = env
        self._config = config
        state_dim = get_state_dim(env)
        action_spec = ActionSpec.from_env(env)

        # Actor and critic networks
        self._actor = Actor(state_dim, action_spec)
        self._critic = Critic(state_dim)

        # Optimizers
        self._actor_opt = Adam(self._actor.parameters(), lr=self._config.lr)
        self._critic_opt = Adam(self._critic.parameters(), lr=self._config.lr)

        # Replay buffer
        self._replay_buffer = ReplayBuffer(
            env=env,
            capacity=self._config.replay_buffer_capacity,
            max_n_timesteps_per_episode=self._config.max_n_timesteps_per_episode,
            gamma=self._config.gamma
        )

    @property
    def actor(self) -> Actor:
        """The policy network."""
        return self._actor

    @property
    def critic(self) -> Critic:
        """The value network."""
        return self._critic

    def learn(self) -> list[float]:
        """Alternate between collecting transitions and updating the networks.

        Each of the ``n_epochs`` iterations fills the replay buffer with the
        current policy, logs the mean reward of the collected transitions,
        trains actor and critic on them and empties the buffer again.

        Returns:
            The mean per-step reward of every iteration, in order.
        """
        # Created once, outside the loop, so it accumulates across epochs
        avg_step_rewards = []

        for epoch in range(self._config.n_epochs):

            logging.info(f"PPO epoch: {epoch + 1}")

            # Collect transitions
            self._replay_buffer.collect(self._actor)

            # The buffer holds single steps, so this is the mean reward
            # per step, not per episode.
            avg_step_reward = np.mean(
                [transition.reward for transition in self._replay_buffer]
            )
            avg_step_rewards.append(float(avg_step_reward))
            logging.info(f"average reward per step: {avg_step_reward}")

            # Create a data loader
            replay_buffer_loader = DataLoader(
                self._replay_buffer,
                batch_size=self._config.batch_size
            )

            # Train the actor and critic
            update_actor_critic(
                actor=self._actor,
                critic=self._critic,
                actor_opt=self._actor_opt,
                critic_opt=self._critic_opt,
                replay_buffer_loader=replay_buffer_loader,
                n_epochs=self._config.n_epochs_for_actor_critic
            )

            # Clear replay buffer
            self._replay_buffer.clear()

        return avg_step_rewards


## Training

In [102]:
# Build the agent on the demo environment and train it.
# Progress is reported through the logging configuration set up at the top.
ppo = PPO(
    env,
    config=PPOConfig(
        n_epochs=100,
        replay_buffer_capacity=4800,
        max_n_timesteps_per_episode=1600,
        gamma=0.99,
        batch_size=1024,
        n_epochs_for_actor_critic=10,
        lr=0.01
    )
)

ppo.learn()

2023-09-21 14:43:53,491 | INFO | PPO epoch: 1
2023-09-21 14:43:54,746 | INFO | average episode rewards: -6.110923935401077
2023-09-21 14:43:55,076 | INFO | PPO epoch: 2
2023-09-21 14:43:56,339 | INFO | average episode rewards: -6.071260639772841
2023-09-21 14:43:56,656 | INFO | PPO epoch: 3
2023-09-21 14:43:57,937 | INFO | average episode rewards: -6.286279557600787
2023-09-21 14:43:58,241 | INFO | PPO epoch: 4
2023-09-21 14:43:59,476 | INFO | average episode rewards: -5.504663757900611
2023-09-21 14:43:59,785 | INFO | PPO epoch: 5
2023-09-21 14:44:01,035 | INFO | average episode rewards: -6.153235834902657
2023-09-21 14:44:01,327 | INFO | PPO epoch: 6
2023-09-21 14:44:02,545 | INFO | average episode rewards: -6.730246981997852
2023-09-21 14:44:02,837 | INFO | PPO epoch: 7
2023-09-21 14:44:04,076 | INFO | average episode rewards: -5.464440702286453
2023-09-21 14:44:04,458 | INFO | PPO epoch: 8
2023-09-21 14:44:05,671 | INFO | average episode rewards: -5.593644874240346
2023-09-21 14:44

In [103]:
# torch.save is given the module objects themselves (not their state_dict()),
# so each file holds a whole pickled network.
torch.save(ppo.actor, "../models/actor.pth")
torch.save(ppo.critic, "../models/critic.pth")

## Loading Models

In [104]:
# The files written above contain whole pickled module objects, not plain
# state dicts, so weights-only loading cannot restore them; request full
# unpickling explicitly.
# NOTE: full unpickling can execute arbitrary code. Only load files that
# this notebook produced itself.
actor = torch.load("../models/actor.pth", weights_only=False)
critic = torch.load("../models/critic.pth", weights_only=False)
actor.eval()
critic.eval()

print(actor)
print(critic)

Actor(
  (mlp): MLP(
    (layer1): Linear(in_features=3, out_features=64, bias=True)
    (layer2): Linear(in_features=64, out_features=64, bias=True)
    (layer3): Linear(in_features=64, out_features=1, bias=True)
  )
  (dist_param1): Sequential(
    (0): Linear(in_features=1, out_features=16, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
  (dist_param2): Sequential(
    (0): Linear(in_features=1, out_features=16, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
)
Critic(
  (mlp): MLP(
    (layer1): Linear(in_features=3, out_features=64, bias=True)
    (layer2): Linear(in_features=64, out_features=64, bias=True)
    (layer3): Linear(in_features=64, out_features=1, bias=True)
  )
)


## Performance of the Trained Agent

In [105]:
env = gym.make(
    "Pendulum-v1", 
    max_episode_steps=1000,
    render_mode="rgb_array"
)

env = RecordVideo(
    env, 
    video_folder="../_static/videos",
    name_prefix=f"ppo-{env.spec.id}"
)

# try/finally makes sure the recording wrapper is closed even if the
# rollout raises part-way through.
try:
    state, _ = env.reset()

    for t in count():

        # Select an action (inference only, so no autograd graph is needed)
        state = torch.tensor(state, dtype=torch.float)
        with torch.no_grad():
            action = actor.select_action(state)
        action = action.detach().numpy()

        # Interact with the environment
        next_state, reward, is_terminated, is_truncated, info = env.step(action)

        is_done = is_terminated or is_truncated

        if is_done:
            break

        # Step to the next state
        state = next_state
finally:
    env.close()

  logger.warn(


Moviepy - Building video /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4.
Moviepy - Writing video /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /Users/isaac/Developer/py-projects/linguAML/book/_static/videos/ppo-Pendulum-v1-episode-0.mp4


In [106]:
# Embed the episode recorded by the evaluation cell above
Video("../_static/videos/ppo-Pendulum-v1-episode-0.mp4")