Recurrent Neural Networks: A Technical Deep Dive
Introduction: The Computational Graph Unfolded
At their core, Recurrent Neural Networks (RNNs) operate on sequences by applying the same set of weights recursively at every time step. Let’s take a close look at their architecture and at why they are unreasonably effective at sequence modeling tasks.
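To make the weight sharing concrete, here is a minimal NumPy sketch that unrolls the recurrence over two sequences of different lengths using one fixed set of parameters. The sizes, variable names, and random initialization are illustrative assumptions, not taken from any particular library.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4  # illustrative sizes (assumption)

# One shared set of parameters, reused at every time step
W_xh = rng.standard_normal((hidden_size, input_size)) * 0.1
W_hh = rng.standard_normal((hidden_size, hidden_size)) * 0.1
b_h = np.zeros((hidden_size, 1))

def unroll(xs, h0):
    """Apply the same recurrence to each input column vector in xs."""
    h = h0
    states = []
    for x in xs:
        h = np.tanh(W_hh @ h + W_xh @ x + b_h)
        states.append(h)
    return states

h0 = np.zeros((hidden_size, 1))
short_seq = [rng.standard_normal((input_size, 1)) for _ in range(2)]
long_seq = [rng.standard_normal((input_size, 1)) for _ in range(50)]
short_states = unroll(short_seq, h0)
long_states = unroll(long_seq, h0)

# The number of parameters does not depend on the sequence length
n_params = W_xh.size + W_hh.size + b_h.size
print(len(short_states), len(long_states), n_params)  # 2 50 32
```

The loop is the computational graph "unfolded" in time: one copy of the same step per input, with every copy reading the same W_xh, W_hh, and b_h.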
The Mathematics Behind RNNs
Basic RNN Architecture
The fundamental RNN computation can be expressed as:
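```python
import numpy as np

class RNN:
    # Assumes W_hh, W_xh, W_hy, b_h, b_y are set as attributes (see CharRNN below)
    def step(self, x, h):
        # Update the hidden state from the previous state and the current input
        h_new = np.tanh(np.dot(self.W_hh, h) + np.dot(self.W_xh, x) + self.b_h)
        # Compute the (unnormalized) output vector from the new hidden state
        y = np.dot(self.W_hy, h_new) + self.b_y
        return h_new, y
```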
Here’s what’s happening in detail:
- W_hh: Weights for hidden-to-hidden connections
- W_xh: Weights for input-to-hidden connections
- W_hy: Weights for hidden-to-output connections
- b_h, b_y: Bias terms
- h: Hidden state vector
- x: Input vector
- y: Output vector
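A quick way to check these definitions is to write down the shapes. The sketch below assumes column vectors and illustrative sizes (input size D, hidden size H, output size V are assumptions) and performs a single step with the same equations as RNN.step:

```python
import numpy as np

D, H, V = 5, 8, 5  # input, hidden, output sizes (illustrative assumptions)
rng = np.random.default_rng(1)

W_xh = rng.standard_normal((H, D)) * 0.01  # input-to-hidden: (H, D)
W_hh = rng.standard_normal((H, H)) * 0.01  # hidden-to-hidden: (H, H)
W_hy = rng.standard_normal((V, H)) * 0.01  # hidden-to-output: (V, H)
b_h = np.zeros((H, 1))
b_y = np.zeros((V, 1))

x = np.zeros((D, 1))
x[2] = 1.0            # one-hot input vector, shape (D, 1)
h = np.zeros((H, 1))  # initial hidden state, shape (H, 1)

h_new = np.tanh(W_hh @ h + W_xh @ x + b_h)  # shape (H, 1), entries in (-1, 1)
y = W_hy @ h_new + b_y                      # shape (V, 1)
print(h_new.shape, y.shape)  # (8, 1) (5, 1)
```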
Forward Pass and Backpropagation Through Time (BPTT)
The real magic happens during training. Let’s implement a basic version:
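```python
import numpy as np

def forward_backward_pass(inputs, targets, h_prev, W_xh, W_hh, W_hy, b_h, b_y):
    # inputs, targets: lists of integer character indices; h_prev: (hidden_size, 1)
    vocab_size = W_xh.shape[1]
    xs, hs, ps = {}, {-1: np.copy(h_prev)}, {}
    loss = 0.0
    # Forward pass: unroll the recurrence over the whole sequence
    for t in range(len(inputs)):
        xs[t] = np.zeros((vocab_size, 1))
        xs[t][inputs[t]] = 1                                    # one-hot input
        hs[t] = np.tanh(W_hh @ hs[t - 1] + W_xh @ xs[t] + b_h)  # hidden state
        y = W_hy @ hs[t] + b_y                                  # unnormalized scores
        e = np.exp(y - np.max(y))
        ps[t] = e / np.sum(e)                                   # softmax probabilities
        loss += -np.log(ps[t][targets[t], 0])                   # cross-entropy loss
    # Backward pass: backpropagation through time (BPTT)
    dW_hh, dW_xh, dW_hy = np.zeros_like(W_hh), np.zeros_like(W_xh), np.zeros_like(W_hy)
    db_h, db_y = np.zeros_like(b_h), np.zeros_like(b_y)
    dh_next = np.zeros_like(hs[0])
    for t in reversed(range(len(inputs))):
        dy = np.copy(ps[t])
        dy[targets[t]] -= 1                # d(loss)/d(scores) for softmax + cross-entropy
        dW_hy += dy @ hs[t].T
        db_y += dy
        dh = W_hy.T @ dy + dh_next         # gradient from the output and from step t+1
        dh_raw = (1 - hs[t] ** 2) * dh     # backprop through tanh
        db_h += dh_raw
        dW_xh += dh_raw @ xs[t].T
        dW_hh += dh_raw @ hs[t - 1].T
        dh_next = W_hh.T @ dh_raw          # pass the gradient to the previous time step
    return loss, dW_hh, dW_xh, dW_hy, db_h, db_y
```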
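For reference, these are the gradient recurrences a complete BPTT backward pass computes for this model. The notation is an assumption layered on the code's names: the scores y_t pass through a softmax p_t before the cross-entropy loss L = -Σ_t log p_t[k_t], k_t is the target index at step t, e_{k_t} is its one-hot vector, a_t is the pre-activation, and T is the last time step.

$$
\begin{aligned}
a_t &= W_{hh} h_{t-1} + W_{xh} x_t + b_h, \qquad h_t = \tanh(a_t) \\
\delta y_t &= p_t - e_{k_t} \\
\delta h_t &= W_{hy}^\top \delta y_t + W_{hh}^\top \delta a_{t+1}, \qquad \delta a_{T+1} = 0 \\
\delta a_t &= (1 - h_t \odot h_t) \odot \delta h_t \\
\frac{\partial L}{\partial W_{hy}} &= \sum_t \delta y_t\, h_t^\top, \qquad \frac{\partial L}{\partial b_y} = \sum_t \delta y_t \\
\frac{\partial L}{\partial W_{hh}} &= \sum_t \delta a_t\, h_{t-1}^\top, \qquad \frac{\partial L}{\partial W_{xh}} = \sum_t \delta a_t\, x_t^\top, \qquad \frac{\partial L}{\partial b_h} = \sum_t \delta a_t
\end{aligned}
$$

The term W_hh^T δa_{t+1} is what makes this backpropagation "through time": the gradient at step t depends on the gradient already computed for step t+1, so the loop must run backward over the sequence.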
Character-Level Language Model: A Concrete Example
Let’s implement a character-level language model to demonstrate the power of RNNs:
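```python
import numpy as np

class CharRNN:
    def __init__(self, vocab_size, hidden_size):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        # Initialize weights with small random values and biases with zeros
        self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.W_xh = np.random.randn(hidden_size, vocab_size) * 0.01
        self.W_hy = np.random.randn(vocab_size, hidden_size) * 0.01
        self.b_h = np.zeros((hidden_size, 1))
        self.b_y = np.zeros((vocab_size, 1))

    def step(self, x, h):
        # One time step: new hidden state and unnormalized output scores
        h_new = np.tanh(self.W_hh @ h + self.W_xh @ x + self.b_h)
        y = self.W_hy @ h_new + self.b_y
        return h_new, y

    def sample(self, h, seed_ix, n):
        """Generate n character indices, starting from seed_ix and hidden state h."""
        x = np.zeros((self.vocab_size, 1))
        x[seed_ix] = 1
        generated = []
        for t in range(n):
            h, y = self.step(x, h)
            # Softmax over the vocabulary (shifted by the max for numerical stability)
            e = np.exp(y - np.max(y))
            p = e / np.sum(e)
            ix = np.random.choice(range(self.vocab_size), p=p.ravel())
            # Feed the sampled character back in as the next one-hot input
            x = np.zeros((self.vocab_size, 1))
            x[ix] = 1
            generated.append(ix)
        return generated
```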
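CharRNN works purely with integer indices, so in practice it is paired with a mapping between characters and indices. A minimal sketch, assuming the vocabulary is simply the set of characters that occur in the training text; the names char_to_ix, ix_to_char, one_hot, and decode are illustrative, not part of any library:

```python
import numpy as np

text = "hello world"
chars = sorted(set(text))  # vocabulary: the unique characters
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)

def one_hot(ix, vocab_size):
    """Column vector with a 1 at position ix, the input format CharRNN.sample uses."""
    x = np.zeros((vocab_size, 1))
    x[ix] = 1
    return x

def decode(indices):
    """Turn indices (e.g. the list returned by CharRNN.sample) back into text."""
    return "".join(ix_to_char[i] for i in indices)

encoded = [char_to_ix[ch] for ch in text]
print(vocab_size, decode(encoded))  # 8 hello world
```

With a trained model, decode(model.sample(h, char_to_ix["h"], 200)) would turn the sampled indices into readable text.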
The Unreasonable Effectiveness in Practice
1. Text Generation
When trained on a large corpus of text, a character-level model like this one can learn:
- Proper spelling and word formation
- Basic grammar and punctuation
- Context-appropriate vocabulary
- Genre-specific writing styles
Here’s what makes this particularly remarkable:
- The model receives only one character per time step as input; everything else must be carried in its hidden state
- It has no built-in understanding of words or grammar
- It learns everything from statistical patterns in the sequence
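Concretely, learning from the sequence alone means the training targets are just the input characters shifted by one position: at every step the model is asked to predict the next character. A small sketch, where the helper name make_training_pairs and the chunk length are assumptions for illustration:

```python
def make_training_pairs(text, seq_length):
    """Split text into (inputs, targets) chunks where targets = inputs shifted by one."""
    pairs = []
    # Stop early enough that every chunk has a full-length target
    for start in range(0, len(text) - seq_length, seq_length):
        inputs = text[start:start + seq_length]
        targets = text[start + 1:start + seq_length + 1]
        pairs.append((inputs, targets))
    return pairs

pairs = make_training_pairs("hello world", seq_length=4)
for inputs, targets in pairs:
    print(repr(inputs), "->", repr(targets))
# 'hell' -> 'ello'
# 'o wo' -> ' wor'
```

Before training, each chunk would be mapped to integer indices so it can be fed to forward_backward_pass.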
2. Source Code Generation
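RNNs can even learn the syntax and patterns of programming languages. For example, a CharRNN trained on source code can be sampled starting from a single seed character (its sample method takes one seed index), given character-index mappings char_to_ix and ix_to_char built from the training text:
```python
import numpy as np

def generate_code(model, char_to_ix, ix_to_char, seed="d", length=1000):
    ixs = model.sample(np.zeros((model.hidden_size, 1)), char_to_ix[seed], length)  # zero initial state
    return seed + "".join(ix_to_char[i] for i in ixs)
```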
The model learns:
- Proper indentation
- Matching parentheses and brackets
- Function and variable naming conventions
- Basic programming patterns
3. Mathematics of Memory
The hidden state h acts as the network’s memory. At each time step t:

$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$
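In principle, this recursive formula allows the network to:
- Carry information across many time steps (in practice, vanilla RNNs struggle with long-term dependencies; see the discussion of vanishing gradients below)
- Overwrite information that is no longer relevant
- Combine earlier features into increasingly abstract representations of the sequence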
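The same recursion explains why long-range memory is hard to train. Writing a_t = W_hh h_{t-1} + W_xh x_t + b_h so that h_t = tanh(a_t), and chaining the Jacobians from step k to a later step T:

$$
\begin{aligned}
\frac{\partial h_t}{\partial h_{t-1}} &= \operatorname{diag}(1 - h_t \odot h_t)\, W_{hh} \\
\frac{\partial h_T}{\partial h_k} &= \frac{\partial h_T}{\partial h_{T-1}} \, \frac{\partial h_{T-1}}{\partial h_{T-2}} \cdots \frac{\partial h_{k+1}}{\partial h_k} \\
\left\lVert \frac{\partial h_T}{\partial h_k} \right\rVert_2 &\le \lVert W_{hh} \rVert_2^{\,T-k}
\end{aligned}
$$

Every tanh derivative lies in (0, 1], so if the largest singular value of W_hh is below 1, the gradient signal reaching step k from step T shrinks exponentially in T - k (vanishing gradients). When that singular value is well above 1, the product can instead grow exponentially (exploding gradients).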
Advanced Topics: Dealing with Vanishing Gradients
LSTM Cells
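The Long Short-Term Memory (LSTM) architecture mitigates the vanishing gradient problem by adding a gated cell state c whose update is additive, which lets gradients flow across many time steps more easily: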
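```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, p):
    # p: dict holding W_*, U_*, b_* for the gates f, i, o and the candidate g
    # Gates: forget (f), input (i), output (o), each with values in (0, 1)
    f = sigmoid(p["W_f"] @ x + p["U_f"] @ h_prev + p["b_f"])
    i = sigmoid(p["W_i"] @ x + p["U_i"] @ h_prev + p["b_i"])
    o = sigmoid(p["W_o"] @ x + p["U_o"] @ h_prev + p["b_o"])
    # New memory content (candidate), with values in (-1, 1)
    g = np.tanh(p["W_g"] @ x + p["U_g"] @ h_prev + p["b_g"])
    # Update cell state: keep part of the old memory, add part of the new content
    c = f * c_prev + i * g
    # Update hidden state from the squashed cell state
    h = o * np.tanh(c)
    return h, c
```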
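Why this helps can be read off the cell-state update c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t. Along the direct path through the cell state (ignoring the indirect dependence of the gates on h_{t-1}), the Jacobian reduces to the forget gate:

$$
\frac{\partial c_t}{\partial c_{t-1}} = \operatorname{diag}(f_t), \qquad
\frac{\partial c_T}{\partial c_k} = \operatorname{diag}\!\Big(\prod_{t=k+1}^{T} f_t\Big)
$$

where the product is element-wise. Unlike the vanilla RNN, whose Jacobian ∂h_t/∂h_{t-1} contains a factor of W_hh and a tanh derivative at every step, this path involves neither; when the forget gates stay close to 1, the gradient flowing along the cell state is largely preserved over many steps.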
Gradient Clipping
To prevent exploding gradients, rescale all gradients whenever their combined (global) L2 norm exceeds a threshold:
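```python
import numpy as np

def clip_gradients(gradients, max_norm=5):
    # Global L2 norm over all gradient arrays
    norm = np.sqrt(sum(np.sum(grad ** 2) for grad in gradients))
    if norm > max_norm:
        # Rescale every gradient by the same factor so the global norm equals max_norm
        scale = max_norm / norm
        return [grad * scale for grad in gradients]
    return gradients
```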
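A tiny worked example of the global-norm rule, with the function repeated so the snippet runs on its own: two gradient arrays with joint norm 10 and a threshold of 5 are both halved, so the update direction is unchanged and only its length shrinks. The gradient values are made up for illustration.

```python
import numpy as np

def clip_gradients(gradients, max_norm=5.0):
    # Same rule as above: rescale all arrays if their joint L2 norm is too large
    norm = np.sqrt(sum(np.sum(grad ** 2) for grad in gradients))
    if norm > max_norm:
        return [grad * (max_norm / norm) for grad in gradients]
    return gradients

grads = [np.array([6.0]), np.array([8.0])]  # global norm sqrt(36 + 64) = 10
clipped = clip_gradients(grads, max_norm=5.0)
print(clipped)  # [array([3.]), array([4.])]

small = [np.array([0.3]), np.array([0.4])]  # global norm 0.5, below the threshold
unchanged = clip_gradients(small, max_norm=5.0)  # returned as-is
```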
Practical Tips for Training RNNs
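- Initialization: Use small random weights scaled by the fan-in n_in to prevent tanh saturation (the sqrt(2.0/n_in) factor, He initialization, is designed for ReLU layers; a 1/n_in variance is the more common choice for tanh):
```python
W = np.random.randn(n_out, n_in) * np.sqrt(1.0 / n_in)  # (n_out, n_in) matches the W @ x convention
```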
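- Mini-batch Processing: Process several sequences at once by stacking them as columns, so each time step is a single matrix multiply:
```python
import numpy as np

def process_batch(rnn, batch_inputs):
    # batch_inputs: (seq_length, input_size, batch_size), one sequence per column
    seq_length, _, batch_size = batch_inputs.shape
    h = np.zeros((rnn.W_hh.shape[0], batch_size))  # b_h of shape (hidden, 1) broadcasts
    for t in range(seq_length):
        h, _ = rnn.step(batch_inputs[t], h)
    return h
```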
- Learning Rate Schedule: Decay the learning rate exponentially as training progresses:
```python
learning_rate = base_lr * decay_rate ** (epoch / decay_steps)
```
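The exponential schedule multiplies the learning rate by decay_rate once every decay_steps epochs, and decays it smoothly in between because epoch / decay_steps is a float. A sketch with illustrative values (base_lr = 0.1, decay_rate = 0.5, decay_steps = 10 are assumptions; the function name is hypothetical):

```python
def exponential_decay(base_lr, decay_rate, epoch, decay_steps):
    """Learning rate after `epoch` epochs; multiplied by decay_rate every decay_steps epochs."""
    return base_lr * decay_rate ** (epoch / decay_steps)

for epoch in (0, 10, 20, 30):
    print(epoch, exponential_decay(base_lr=0.1, decay_rate=0.5, epoch=epoch, decay_steps=10))
```

With these values the rate halves every 10 epochs: 0.1, 0.05, 0.025, 0.0125.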
Beyond Simple Sequences
RNNs can be extended to handle more complex patterns:
- Bidirectional RNNs: Process sequences in both directions
- Deep RNNs: Stack multiple RNN layers
- Attention Mechanisms: Allow the network to focus on relevant parts of the input sequence
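The first two extensions can be sketched with nothing more than the vanilla recurrence. In this minimal NumPy sketch (layer sizes, the random initialization, and the helper names init_layer, run_layer, and bidirectional are assumptions), a bidirectional layer runs one RNN left to right and a second one right to left, then concatenates their states at each step, while a deep RNN feeds the hidden states of one layer into the next as inputs. Attention is not covered here.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_layer(input_size, hidden_size):
    """Random parameters for one vanilla RNN layer (illustrative initialization)."""
    return {
        "W_xh": rng.standard_normal((hidden_size, input_size)) * 0.1,
        "W_hh": rng.standard_normal((hidden_size, hidden_size)) * 0.1,
        "b_h": np.zeros((hidden_size, 1)),
    }

def run_layer(params, xs):
    """Run one RNN layer over a list of column vectors and return all hidden states."""
    h = np.zeros_like(params["b_h"])
    states = []
    for x in xs:
        h = np.tanh(params["W_hh"] @ h + params["W_xh"] @ x + params["b_h"])
        states.append(h)
    return states

def bidirectional(fwd, bwd, xs):
    """Concatenate a left-to-right pass with a right-to-left pass at each time step."""
    forward_states = run_layer(fwd, xs)
    backward_states = run_layer(bwd, xs[::-1])[::-1]  # re-align to the original order
    return [np.vstack([f, b]) for f, b in zip(forward_states, backward_states)]

input_size, hidden_size, T = 3, 4, 6
xs = [rng.standard_normal((input_size, 1)) for _ in range(T)]

# Bidirectional layer: each output depends on both past and future inputs
bi_states = bidirectional(init_layer(input_size, hidden_size),
                          init_layer(input_size, hidden_size), xs)

# Deep (stacked) RNN: the hidden states of layer 1 are the inputs of layer 2
layer1 = run_layer(init_layer(input_size, hidden_size), xs)
layer2 = run_layer(init_layer(hidden_size, hidden_size), layer1)

print(len(bi_states), bi_states[0].shape, len(layer2), layer2[0].shape)  # 6 (8, 1) 6 (4, 1)
```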
Conclusion
The effectiveness of RNNs comes from their ability to learn complex patterns through simple, recursive operations. While newer architectures like Transformers have emerged, the fundamental insights from RNNs continue to influence deep learning design.
Understanding their mathematical foundations and implementation details helps us appreciate why they work so well and how to use them effectively in practice.