pradachan's picture
Upload folder using huggingface_hub
f71c233 verified
raw
history blame
1.26 kB
import numpy as np
class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy selection over three actions.

    States are hashable keys (see ``get_state``); each state maps to a list of
    three Q-values stored in ``q_table``. The actions exposed to callers are
    ``-1``, ``0`` and ``1``; the Q-value of action ``a`` lives at index
    ``a + 1`` of the state's list.

    NOTE(review): the actions presumably mean "lower / keep / raise the
    learning rate" given the ``get_state`` inputs -- confirm against the caller.
    """

    # Actions handed to callers, in Q-table column order (index = action + 1).
    _ACTIONS = (-1, 0, 1)

    def __init__(self, lr=0.1, gamma=0.95, epsilon=0.1, epsilon_decay=0.99):
        """Create an agent with an empty Q-table.

        Args:
            lr: Step size applied to the temporal-difference error.
            gamma: Discount factor applied to the best next-state Q-value.
            epsilon: Initial probability of picking a random action.
            epsilon_decay: Factor ``epsilon`` is multiplied by each time a
                random (exploratory) action is taken.
        """
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        # Kept so callers can inspect or restore the starting exploration rate.
        self.initial_epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.q_table = {}

    def _q_values(self, state):
        """Return the Q-value list for ``state``, creating a zero row if absent."""
        return self.q_table.setdefault(state, [0, 0, 0])

    def get_state(self, val_loss, current_lr):
        """Discretise a (validation loss, learning rate) pair into a state key.

        Args:
            val_loss: Scalar loss. Either an object exposing ``.item()`` (for
                example a NumPy scalar or a one-element tensor) or a plain
                Python number.
            current_lr: Current learning rate as a Python number.

        Returns:
            A tuple ``(loss rounded to 2 decimals, lr rounded to 5 decimals)``.
        """
        # Plain Python floats have no ``.item()``; accept them as well.
        if hasattr(val_loss, "item"):
            loss_value = val_loss.item()
        else:
            loss_value = float(val_loss)
        return (round(loss_value, 2), round(current_lr, 5))

    def choose_action(self, state):
        """Pick an action for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned and
        ``epsilon`` is multiplied by ``epsilon_decay``. Note that the decay is
        applied only on these exploratory steps, and that an exploratory step
        does not add ``state`` to the Q-table. Otherwise the action with the
        highest Q-value is returned (the first one on ties).

        Returns:
            One of ``-1``, ``0`` or ``1`` (as a NumPy integer).
        """
        if np.random.rand() < self.epsilon:
            self.epsilon *= self.epsilon_decay
            return np.random.choice(self._ACTIONS)
        # argmax yields a column index 0..2; shift it back to the action value.
        return np.argmax(self._q_values(state)) - 1

    def update_q_values(self, state, action, reward, next_state):
        """Apply one Q-learning update for the observed transition.

        Args:
            state: State key in which ``action`` was taken.
            action: The action taken; must be ``-1``, ``0`` or ``1``.
            reward: Scalar reward received for the transition.
            next_state: State key reached after the action.

        Raises:
            IndexError: If ``action`` is not one of ``-1``, ``0`` or ``1``.
        """
        # Without this check an action such as -2 would wrap around to a
        # negative list index and silently update the wrong Q-value.
        if action not in self._ACTIONS:
            raise IndexError(
                f"action must be one of {self._ACTIONS}, got {action!r}"
            )
        current_q = self._q_values(state)
        next_q = self._q_values(next_state)
        column = action + 1
        td_target = reward + self.gamma * max(next_q)
        td_error = td_target - current_q[column]
        current_q[column] += self.lr * td_error