Agent#

from rich import print
import random

from linguaml.tolearn.family import Family
from linguaml.tolearn.hp.bounds import NumericHPBounds
from linguaml.rl.action import ActionConfig, Action
from linguaml.rl.state import StateConfig, State

family = Family.SVC

numeric_hp_bounds = NumericHPBounds.from_dict({
    "C": (0.1, 100),
    "gamma": (0.1, 100),
    "tol": (1e-5, 1e-3)
})

ActionConfig.family = family
ActionConfig.numeric_hp_bounds = numeric_hp_bounds

StateConfig.lookback = 5

def generate_random_action() -> Action:
    
    # Create an empty action
    action = Action()
    
    # Continuous actions
    for hp_name in family.numeric_hp_names():
            
        # Generate a random number in [0, 1]
        action[hp_name] = random.random()
        
    # Discrete actions
    for hp_name in family.categorical_hp_names():
        
        # Get the number of levels in the category
        n_levels = family.n_levels_in_category(hp_name)
        
        # Generate a random integer in [0, n_levels - 1]
        action[hp_name] = random.randint(0, n_levels - 1)
    
    return action

def generate_random_actions(n: int) -> list[Action]:
    
    actions = [
        generate_random_action()
        for _ in range(n)
    ]
    
    return actions

def generate_random_state() -> State:
    
    actions = generate_random_actions(StateConfig.lookback)
    
    rewards = [
        random.random()
        for _ in range(StateConfig.lookback)
    ]
    
    state = State.from_actions_and_rewards(actions, rewards)
    
    return state

def generate_random_states(n: int) -> list[State]:
    
    states = [
        generate_random_state()
        for _ in range(n)
    ]
    
    return states

generate_random_action()
{'C': 0.709065814257011,
 'gamma': 0.908919155394657,
 'tol': 0.8326055084880106,
 'kernel': 3,
 'decision_function_shape': 2}

Selecting Actions#

Random Actions#

from linguaml.rl.agent import Agent, ContinuousDistributionFamily
from linguaml.tolearn.family import Family

agent = Agent(
    family=Family.SVC,
    numeric_hp_bounds={
        "C": (0.1, 100),
        "gamma": (0.1, 100),
        "tol": (1e-5, 1e-3)
    },
    hidden_size=128,
    cont_dist_family=ContinuousDistributionFamily.NORMAL
)

print(agent.select_random_action())
{
    'C': 0.7254745717836685,
    'gamma': 0.414141520089093,
    'tol': 0.6514806593112382,
    'kernel': 1,
    'decision_function_shape': 0
}

Single Action#

Create an agent:

from linguaml.rl.agent import Agent, ContinuousDistributionFamily
from linguaml.tolearn.family import Family

# Create an agent
agent = Agent(
    family=Family.SVC,
    numeric_hp_bounds={
        "C": (0.1, 100),
        "gamma": (0.1, 100),
        "tol": (1e-5, 1e-3)
    },
    hidden_size=128,
    cont_dist_family=ContinuousDistributionFamily.NORMAL
)

# Generate a random state
state = generate_random_state()

# Select an action
action = agent.select_action(state)

print(action)
{'C': 0.5211094617843628, 'kernel': 1, 'gamma': 0.0, 'tol': 0.7661858797073364, 'decision_function_shape': 0}

Batched Actions#

from linguaml.rl.agent import Agent, ContinuousDistributionFamily
from linguaml.tolearn.family import Family

agent = Agent(
    family=Family.SVC,
    numeric_hp_bounds={
        "C": (0.1, 100),
        "gamma": (0.1, 100),
        "tol": (1e-5, 1e-3)
    },
    hidden_size=128,
    cont_dist_family=ContinuousDistributionFamily.NORMAL
)

from linguaml.rl.state import StateConfig, State, BatchedStates
import random

# Set the look back period
StateConfig.lookback = 5

# Generate a list of random states
states = generate_random_states(10)

# Convert to batched states
batched_states = BatchedStates.from_states(states)
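# data has shape (number of states, lookback, feature length per time step)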
print(f"shape of data of batched states: {batched_states.data.shape}")

# Select actions
batched_actions = agent.select_action(batched_states)
print(batched_actions)
print(batched_actions.to_hp_configs())
shape of data of batched states: (10, 5, 10)
{
    'C': array([0.        , 0.4485766 , 0.74467874, 0.26534462, 0.33302683,
       0.        , 0.01054871, 0.7600914 , 0.21722707, 0.24913226],
      dtype=float32),
    'kernel': array([2, 1, 0, 3, 3, 1, 3, 3, 1, 3]),
    'gamma': array([0.51459336, 0.43040276, 0.88630176, 0.        , 1.        ,
       0.28879407, 0.5592177 , 0.        , 1.        , 0.5656874 ],
      dtype=float32),
    'tol': array([0.1928125 , 0.        , 0.        , 0.7702479 , 1.        ,
       0.26864272, 0.        , 0.27716446, 0.72063833, 0.        ],
      dtype=float32),
    'decision_function_shape': array([1, 1, 1, 1, 0, 1, 0, 0, 1, 0])
}
[
    SVCConfig(
        C=0.1,
        kernel='rbf',
        gamma=51.507876944541934,
        tol=0.00020088437736034392,
        decision_function_shape='ovr'
    ),
    SVCConfig(
        C=44.912802276015285,
        kernel='poly',
        gamma=43.09723529815674,
        tol=1e-05,
        decision_function_shape='ovr'
    ),
    SVCConfig(
        C=74.49340569972992,
        kernel='linear',
        gamma=88.64154541492462,
        tol=1e-05,
        decision_function_shape='ovr'
    ),
    SVCConfig(
        C=26.60792751312256,
        kernel='sigmoid',
        gamma=0.1,
        tol=0.0007725453978776932,
        decision_function_shape='ovr'
    ),
    SVCConfig(C=33.36937995553017, kernel='sigmoid', gamma=100.0, tol=0.001, decision_function_shape='ovo'),
    SVCConfig(
        C=0.1,
        kernel='poly',
        gamma=28.95052764117718,
        tol=0.00027595629632472994,
        decision_function_shape='ovr'
    ),
    SVCConfig(
        C=1.1538162112236023,
        kernel='sigmoid',
        gamma=55.965847373008735,
        tol=1e-05,
        decision_function_shape='ovo'
    ),
    SVCConfig(
        C=76.03313325643539,
        kernel='sigmoid',
        gamma=0.1,
        tol=0.0002843928146362305,
        decision_function_shape='ovo'
    ),
    SVCConfig(
        C=21.800984445214276,
        kernel='poly',
        gamma=100.0,
        tol=0.0007234319514036179,
        decision_function_shape='ovr'
    ),
    SVCConfig(
        C=24.98831284195185,
        kernel='sigmoid',
        gamma=56.61217305660248,
        tol=1e-05,
        decision_function_shape='ovo'
    )
]
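
Note that each numeric entry of a batched action lies on the unit interval [0, 1]; converting to hyperparameter configurations rescales it linearly into the declared bounds, and each categorical index picks a level (in the output above, kernel indices 0–3 correspond to 'linear', 'poly', 'rbf' and 'sigmoid'). As a quick arithmetic check of the rescaling against the first action above (a sketch of the mapping, not LinguaML's internal code):

# The first action's gamma is 0.51459336 on the unit scale; rescaling it into
# the declared bounds (0.1, 100) reproduces the gamma of the first SVCConfig
lower, upper = 0.1, 100
print(lower + 0.51459336 * (upper - lower))  # ~51.508, cf. gamma=51.5078... above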

Log-Probabilities#

Single Data#

from linguaml.rl.agent import Agent, ContinuousDistributionFamily
from linguaml.tolearn.family import Family

# Create an agent
agent = Agent(
    family=Family.SVC,
    numeric_hp_bounds={
        "C": (0.1, 100),
        "gamma": (0.1, 100),
        "tol": (1e-5, 1e-3)
    },
    hidden_size=128,
    cont_dist_family=ContinuousDistributionFamily.NORMAL
)

# Create a random state
state = generate_random_state()

# Select an action
action = agent.select_action(state)

If the state argument is omitted when calling the agent’s log_prob method, the log-probability of the action is computed with respect to the latest state the agent has seen:

# Get the log probability of the action based on the latest state
log_prob = agent.log_prob(action)

print(log_prob)
tensor(-3.5283, grad_fn=<SumBackward1>)

Of course, this is equivalent to:

# Get the log probability of the action based on the provided state
log_prob = agent.log_prob(action, state)

print(log_prob)
tensor(-3.5283, grad_fn=<SumBackward1>)

However, when the state is passed explicitly, the agent regenerates the distributions used for selecting actions by calling its forward method again. Hence, you can omit the state argument and save some computation whenever the log-probability you want is indeed the one based on the latest state.
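
As a rough illustration of this cost, here is a minimal timing sketch (exact numbers depend on the network size and hardware; only the two log_prob call forms shown above are used):

import time

# Reuses the distributions from the agent's most recent forward pass
start = time.perf_counter()
for _ in range(1000):
    agent.log_prob(action)
t_without_state = time.perf_counter() - start

# Re-runs the forward pass on every call
start = time.perf_counter()
for _ in range(1000):
    agent.log_prob(action, state)
t_with_state = time.perf_counter() - start

print(f"without state: {t_without_state:.3f}s, with state: {t_with_state:.3f}s")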

Batched Data#

from linguaml.rl.agent import Agent, ContinuousDistributionFamily
from linguaml.tolearn.family import Family
from linguaml.rl.state import BatchedStates

# Create an agent
agent = Agent(
    family=Family.SVC,
    numeric_hp_bounds={
        "C": (0.1, 100),
        "gamma": (0.1, 100),
        "tol": (1e-5, 1e-3)
    },
    hidden_size=128,
    cont_dist_family=ContinuousDistributionFamily.NORMAL
)


# Create a random batch of states
batched_states = BatchedStates.from_states(generate_random_states(10))

# Select batched actions
batched_actions = agent.select_action(batched_states)

# Compute the log probabilities of the batched actions
log_probs = agent.log_prob(batched_actions)

print(log_probs)
tensor([-3.7806, -3.1261, -3.6266, -3.9272, -3.4137, -3.3805, -4.0454, -3.8760,
        -3.6255, -3.4482], grad_fn=<SumBackward1>)
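
Since these log-probabilities carry gradient information (note the grad_fn), they can be plugged directly into a policy-gradient objective. Below is a minimal REINFORCE-style sketch with placeholder rewards, not LinguaML’s actual training loop:

import torch

# Placeholder rewards for the 10 batched actions (illustrative values only)
rewards = torch.rand(10)

# REINFORCE-style surrogate loss: reward-weighted negative log-probabilities
loss = -(log_probs * rewards).mean()

# Backpropagate through the agent's policy network
loss.backward()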