vinci rufus

Recurrent Neural Networks: A Technical Deep Dive

2024-11-15
4 minutes

Introduction: The Computational Graph Unfolded

At their core, Recurrent Neural Networks (RNNs) operate on sequences by applying the same set of weights recursively at every time step. Let’s dive into their architecture and see why they are unreasonably effective at sequence modeling tasks.

The Mathematics Behind RNNs

Basic RNN Architecture

The fundamental RNN computation can be expressed as:

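import numpy as np

class RNN:
    # Assumes W_hh, W_xh, W_hy, b_h, b_y are set as attributes; vectors are columns
    def step(self, x, h):
        # Update the hidden state
        h_new = np.tanh(self.W_hh @ h + self.W_xh @ x + self.b_h)
        # Compute the output vector (unnormalized scores)
        y = self.W_hy @ h_new + self.b_y
        return h_new, y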

Here’s what’s happening in detail:

  • W_hh: Weights for hidden-to-hidden connections
  • W_xh: Weights for input-to-hidden connections
  • W_hy: Weights for hidden-to-output connections
  • b_h, b_y: Bias terms
  • h: Hidden state vector
  • x: Input vector
  • y: Output vector
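To see that the same weights are reused at every time step, here is a minimal sketch that unrolls the step over a short random sequence. The sizes (input_size=3, hidden_size=4, output_size=2, 5 time steps), the 0.1 weight scale, and the helper name `run_sequence` are illustrative assumptions; vectors are column vectors, matching the weight shapes listed above.

```python
import numpy as np

def run_sequence(xs, h0, W_hh, W_xh, W_hy, b_h, b_y):
    """Apply the same RNN step to every input in xs; return all hidden states and outputs."""
    h = h0
    hs, ys = [], []
    for x in xs:  # identical weights at every time step; only h changes
        h = np.tanh(W_hh @ h + W_xh @ x + b_h)
        y = W_hy @ h + b_y
        hs.append(h)
        ys.append(y)
    return hs, ys

rng = np.random.default_rng(0)
input_size, hidden_size, output_size, T = 3, 4, 2, 5
W_hh = rng.standard_normal((hidden_size, hidden_size)) * 0.1
W_xh = rng.standard_normal((hidden_size, input_size)) * 0.1
W_hy = rng.standard_normal((output_size, hidden_size)) * 0.1
b_h = np.zeros((hidden_size, 1))
b_y = np.zeros((output_size, 1))

xs = [rng.standard_normal((input_size, 1)) for _ in range(T)]
hs, ys = run_sequence(xs, np.zeros((hidden_size, 1)), W_hh, W_xh, W_hy, b_h, b_y)
print(len(hs), hs[-1].shape, ys[-1].shape)  # 5 (4, 1) (2, 1)
```

The parameter count does not depend on the sequence length: a 5-step and a 500-step sequence use exactly the same five arrays.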

Forward Pass and Backpropagation Through Time (BPTT)

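The real work happens during training: the network is unrolled over the whole sequence, and gradients are propagated backward through every time step, which is why the procedure is called backpropagation through time. Here is a basic implementation: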

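import numpy as np

def forward_backward_pass(inputs, targets, h_prev, W_hh, W_xh, W_hy, b_h, b_y):
    # inputs, targets: lists of character indices; h_prev: (hidden_size, 1)
    vocab_size = W_xh.shape[1]
    xs, hs, ps = {}, {-1: np.copy(h_prev)}, {}
    loss = 0.0

    # Forward pass
    for t in range(len(inputs)):
        xs[t] = np.zeros((vocab_size, 1))
        xs[t][inputs[t]] = 1                          # one-hot encode the input
        hs[t] = np.tanh(W_hh @ hs[t - 1] + W_xh @ xs[t] + b_h)
        y = W_hy @ hs[t] + b_y                        # unnormalized scores
        e = np.exp(y - np.max(y))
        ps[t] = e / np.sum(e)                         # softmax probabilities
        loss += -np.log(ps[t][targets[t], 0])         # cross-entropy loss

    # Backward pass
    dW_hh, dW_xh, dW_hy = np.zeros_like(W_hh), np.zeros_like(W_xh), np.zeros_like(W_hy)
    db_h, db_y = np.zeros_like(b_h), np.zeros_like(b_y)
    dh_next = np.zeros_like(hs[0])

    for t in reversed(range(len(inputs))):
        # This is where BPTT happens
        dy = np.copy(ps[t])
        dy[targets[t]] -= 1                           # gradient of softmax + cross-entropy
        dW_hy += dy @ hs[t].T
        db_y += dy
        dh = W_hy.T @ dy + dh_next                    # from the output and from step t+1
        dh_raw = (1 - hs[t] ** 2) * dh                # backprop through tanh
        db_h += dh_raw
        dW_xh += dh_raw @ xs[t].T
        dW_hh += dh_raw @ hs[t - 1].T
        dh_next = W_hh.T @ dh_raw                     # carry the gradient back to step t-1

    return loss, dW_hh, dW_xh, dW_hy, db_h, db_y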
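To make the mechanics of BPTT concrete at the smallest possible scale, here is a sketch on a scalar RNN, h_t = tanh(w·h_{t-1} + u·x_t), with the illustrative loss L = Σ_t h_t (an assumption chosen only to keep the math short). The backward loop carries `dh_next` from later steps into earlier ones, the same role `dh_next` plays in `forward_backward_pass`, and the result is checked against a central finite-difference estimate. The function name `loss_and_grad` and the numeric values are hypothetical.

```python
import numpy as np

def loss_and_grad(w, u, xs, h0=0.0):
    """Scalar RNN h_t = tanh(w*h_{t-1} + u*x_t) with loss L = sum_t h_t.
    Returns L and dL/dw, dL/du computed by backpropagation through time."""
    hs = [h0]
    for x in xs:                          # forward pass, storing every state
        hs.append(np.tanh(w * hs[-1] + u * x))
    loss = sum(hs[1:])

    dw = du = 0.0
    dh_next = 0.0                         # gradient flowing back from step t+1
    for t in reversed(range(len(xs))):
        h, h_prev = hs[t + 1], hs[t]
        dh = 1.0 + dh_next                # dL/dh_t: direct term plus future steps
        dz = dh * (1.0 - h ** 2)          # back through tanh
        dw += dz * h_prev
        du += dz * xs[t]
        dh_next = dz * w                  # pass the gradient on to h_{t-1}
    return loss, dw, du

xs = [0.5, -1.0, 0.3, 0.8]
w, u, eps = 0.7, 1.2, 1e-6
_, dw, du = loss_and_grad(w, u, xs)
num_dw = (loss_and_grad(w + eps, u, xs)[0] - loss_and_grad(w - eps, u, xs)[0]) / (2 * eps)
num_du = (loss_and_grad(w, u + eps, xs)[0] - loss_and_grad(w, u - eps, xs)[0]) / (2 * eps)
print(dw, num_dw)
print(du, num_du)
```

A gradient check like this is a cheap way to validate any hand-written backward pass before training with it.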

Character-Level Language Model: A Concrete Example

Let’s implement a character-level language model to demonstrate the power of RNNs:

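import numpy as np

class CharRNN:
    def __init__(self, vocab_size, hidden_size):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        
        # Initialize weights with small random values
        self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.W_xh = np.random.randn(hidden_size, vocab_size) * 0.01
        self.W_hy = np.random.randn(vocab_size, hidden_size) * 0.01
        self.b_h = np.zeros((hidden_size, 1))
        self.b_y = np.zeros((vocab_size, 1))

    def step(self, x, h):
        # Same recurrence as the basic RNN step above
        h = np.tanh(self.W_hh @ h + self.W_xh @ x + self.b_h)
        y = self.W_hy @ h + self.b_y
        return h, y
    
    def sample(self, h, seed_ix, n):
        # Generate n character indices, starting from hidden state h and character seed_ix
        x = np.zeros((self.vocab_size, 1))
        x[seed_ix] = 1
        generated = []
        
        for _ in range(n):
            h, y = self.step(x, h)
            e = np.exp(y - np.max(y))            # softmax, shifted for numerical stability
            p = e / np.sum(e)
            ix = np.random.choice(self.vocab_size, p=p.ravel())
            x = np.zeros((self.vocab_size, 1))
            x[ix] = 1                             # feed the sampled character back in
            generated.append(ix)
            
        return generated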
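Both training and sampling work with integer character indices rather than raw text. A minimal sketch of how a plain-string corpus might be turned into those indices; the names `build_vocab`, `make_example`, `char_to_ix`, and `ix_to_char` are illustrative, not from any library. Each target is simply the next character, so the model learns to predict text[t+1] from everything up to text[t].

```python
def build_vocab(text):
    """Map each distinct character to an integer index and back."""
    chars = sorted(set(text))
    char_to_ix = {ch: i for i, ch in enumerate(chars)}
    ix_to_char = {i: ch for i, ch in enumerate(chars)}
    return char_to_ix, ix_to_char

def make_example(text, start, seq_length, char_to_ix):
    """Inputs are seq_length characters; targets are the same window shifted by one."""
    chunk = text[start:start + seq_length + 1]
    inputs = [char_to_ix[ch] for ch in chunk[:-1]]
    targets = [char_to_ix[ch] for ch in chunk[1:]]
    return inputs, targets

text = "hello world"
char_to_ix, ix_to_char = build_vocab(text)
inputs, targets = make_example(text, 0, 4, char_to_ix)
print("".join(ix_to_char[i] for i in inputs), "->", "".join(ix_to_char[i] for i in targets))
# hell -> ello
```

The vocabulary size (8 here: space plus seven letters) becomes the `vocab_size` passed to `CharRNN`.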

The Unreasonable Effectiveness in Practice

1. Text Generation

When trained on a large enough corpus of text, a character-level model like this one can learn:

  • Proper spelling and word formation
  • Basic grammar and punctuation
  • Context-appropriate vocabulary
  • Genre-specific writing styles

Here’s what makes this particularly remarkable:

  • The model only sees one character at a time
  • It has no built-in understanding of words or grammar
  • It learns everything from statistical patterns in the sequence
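How the model’s output scores are turned into characters also shapes the generated text. One commonly used knob, which the `sample` method above does not include, is a temperature that divides the scores before the softmax: values below 1 make sampling more conservative, values above 1 make it more diverse. A minimal sketch; the function name `softmax_with_temperature` is an illustrative assumption.

```python
import numpy as np

def softmax_with_temperature(scores, temperature=1.0):
    """Turn unnormalized scores into probabilities, sharpened (T < 1) or flattened (T > 1)."""
    z = np.asarray(scores, dtype=float) / temperature
    z -= z.max()                      # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

scores = [2.0, 1.0, 0.1]
for temp in (0.5, 1.0, 2.0):
    print(temp, np.round(softmax_with_temperature(scores, temp), 3))
```

In `CharRNN.sample`, the line computing `p` could be replaced by `p = softmax_with_temperature(y.ravel(), temperature)` with a hypothetical extra `temperature` argument.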

2. Source Code Generation

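RNNs can even learn the syntax and patterns of programming languages. After training the CharRNN on source code, generating a snippet is just sampling from a seed character. For example (char_to_ix and ix_to_char are assumed to be the character-to-index lookup tables built from the training corpus):

def generate_code(model, char_to_ix, ix_to_char, seed="d", length=1000):
    h = np.zeros((model.hidden_size, 1))  # start with an empty memory
    ixs = model.sample(h, seed_ix=char_to_ix[seed], n=length)
    return seed + "".join(ix_to_char[ix] for ix in ixs)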

A well-trained model picks up:

  • Proper indentation
  • Matching parentheses and brackets
  • Function and variable naming conventions
  • Basic programming patterns

3. Mathematics of Memory

The hidden state h acts as the network’s memory. At each time step t:

h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h)

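In principle, this recursive formula allows the network to:

  • Carry information across many time steps (in practice, a vanilla RNN retains it only over fairly short spans, as the next section explains)
  • Overwrite information that is no longer relevant
  • Build increasingly abstract summaries of everything seen so far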
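One way to watch this memory at work, and its limits, is to perturb only the first input and measure how much the hidden state changes at later steps. A minimal sketch with small random weights (the sizes, weight scales, and seed are arbitrary assumptions); with a recurrent matrix this small, the influence of x_0 fades quickly, which is the practical face of the vanishing-gradient problem.

```python
import numpy as np

rng = np.random.default_rng(1)
hidden_size, input_size, T = 8, 3, 20
W_hh = rng.standard_normal((hidden_size, hidden_size)) * 0.1
W_xh = rng.standard_normal((hidden_size, input_size)) * 0.5
b_h = np.zeros((hidden_size, 1))

def hidden_states(xs):
    """Run h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h) from h = 0 and return all states."""
    h = np.zeros((hidden_size, 1))
    states = []
    for x in xs:
        h = np.tanh(W_hh @ h + W_xh @ x + b_h)
        states.append(h)
    return states

xs = [rng.standard_normal((input_size, 1)) for _ in range(T)]
perturbed = [x.copy() for x in xs]
perturbed[0] += 1.0                   # change only the first input

base, pert = hidden_states(xs), hidden_states(perturbed)
for t in (0, 1, 5, 10, 19):
    print(t, np.linalg.norm(pert[t] - base[t]))
```

Because tanh is 1-Lipschitz and later inputs are identical, each step can shrink the difference by at most a factor of the spectral norm of W_hh, so a small recurrent matrix forgets quickly.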

Advanced Topics: Dealing with Vanishing Gradients

LSTM Cells

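The Long Short-Term Memory (LSTM) architecture mitigates the vanishing gradient problem by adding a cell state c that is updated additively and controlled by forget (f), input (i), and output (o) gates: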

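import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, params):
    # params: dict holding W_*, U_* weight matrices and b_* biases for the f, i, o, g parts
    # Gates
    f = sigmoid(params["W_f"] @ x + params["U_f"] @ h_prev + params["b_f"])  # forget gate
    i = sigmoid(params["W_i"] @ x + params["U_i"] @ h_prev + params["b_i"])  # input gate
    o = sigmoid(params["W_o"] @ x + params["U_o"] @ h_prev + params["b_o"])  # output gate
    # New memory content
    g = np.tanh(params["W_g"] @ x + params["U_g"] @ h_prev + params["b_g"])
    # Update cell state (additive path that lets gradients flow across time steps)
    c = f * c_prev + i * g
    # Update hidden state
    h = o * np.tanh(c)
    return h, c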
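Why does the additive cell update help? Along the direct cell-state path, ∂c_t/∂c_{t-1} = diag(f_t), so the gradient reaching c_0 from c_T is scaled by a product of forget gates, while a vanilla RNN multiplies by diag(1 - h_t²)·W_hh at every step. The sketch below compares the two products over 50 steps under deliberately simple assumptions: a vanilla RNN with W_hh = 0.5·I driven by random inputs, and forget gates fixed at 0.95 (both illustrative values; the LSTM’s other gradient paths through its gates are ignored).

```python
import numpy as np

T, hidden_size = 50, 4
rng = np.random.default_rng(0)

# Vanilla RNN: dh_T/dh_0 = J_T ... J_1 with J_t = diag(1 - h_t**2) @ W_hh
W_hh = 0.5 * np.eye(hidden_size)
h = np.zeros((hidden_size, 1))
rnn_jacobian = np.eye(hidden_size)
for _ in range(T):
    h = np.tanh(W_hh @ h + rng.standard_normal((hidden_size, 1)))  # random input drive
    rnn_jacobian = np.diag(1 - h.ravel() ** 2) @ W_hh @ rnn_jacobian

# LSTM direct cell path: dc_T/dc_0 = prod_t diag(f_t), with f_t fixed at 0.95 here
f = np.full(hidden_size, 0.95)
lstm_jacobian = np.diag(f ** T)

rnn_norm = np.linalg.norm(rnn_jacobian, 2)
lstm_norm = np.linalg.norm(lstm_jacobian, 2)
print("vanilla RNN:", rnn_norm)
print("LSTM cell path:", lstm_norm)
```

Real forget gates vary per step and per unit, so this only illustrates the mechanism: a gate close to 1 carries gradient much further than repeated multiplication by a contracting recurrent matrix.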

Gradient Clipping

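To prevent exploding gradients, rescale all gradients together whenever their combined L2 norm exceeds a threshold:

import numpy as np

def clip_gradients(gradients, max_norm=5):
    # Global L2 norm over all gradient arrays
    norm = np.sqrt(sum(np.sum(grad ** 2) for grad in gradients))
    if norm > max_norm:
        # Scale every gradient by the same factor, preserving the update direction
        scale = max_norm / norm
        return [grad * scale for grad in gradients]
    return gradients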
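Global-norm clipping rescales every gradient by the same factor, so the update keeps its direction. A common alternative is element-wise clipping with `np.clip`, which caps each entry independently and can change the direction. A minimal, self-contained comparison; the helper names and the example gradients are illustrative assumptions.

```python
import numpy as np

def clip_by_global_norm(gradients, max_norm=5.0):
    """Scale all gradients together so their combined L2 norm is at most max_norm."""
    norm = np.sqrt(sum(np.sum(g ** 2) for g in gradients))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in gradients]

def clip_elementwise(gradients, limit=5.0):
    """Clamp every gradient entry to [-limit, limit] independently."""
    return [np.clip(g, -limit, limit) for g in gradients]

def flat(gradients):
    return np.concatenate([g.ravel() for g in gradients])

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

grads = [np.array([30.0, 4.0]), np.array([[0.0, -3.0]])]
by_norm = clip_by_global_norm(grads)
by_elem = clip_elementwise(grads)

print("norm after global-norm clip:", np.linalg.norm(flat(by_norm)))
print("cosine to original, global norm:", cosine(flat(grads), flat(by_norm)))
print("cosine to original, element-wise:", cosine(flat(grads), flat(by_elem)))
```

Here the single large entry (30) dominates the gradient; element-wise clipping shrinks it to 5 while leaving the others untouched, which tilts the update, whereas global-norm clipping only shortens it.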

Practical Tips for Training RNNs

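  1. Initialization: Use small random weights, scaled by the fan-in, to keep tanh units out of saturation:
W = np.random.randn(n_out, n_in) * np.sqrt(1.0 / n_in)  # Xavier-style scaling, shape matches W @ x
  2. Mini-batch Processing: Process several sequences in parallel for efficiency, one sequence per column:
def process_batch(rnn, batch_inputs):
    # batch_inputs: array of shape (seq_length, input_size, batch_size)
    h = np.zeros((rnn.W_hh.shape[0], batch_inputs.shape[2]))  # one hidden state per sequence
    for x_t in batch_inputs:  # iterate over time steps
        h, y = rnn.step(x_t, h)
    return h
  3. Learning Rate Schedule: Decay the learning rate as training progresses, for example exponentially:
learning_rate = base_lr * decay_rate ** (epoch / decay_steps)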
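The scale of the initial weights decides whether tanh units start in their sensitive range. A quick check, assuming unit-variance inputs and an arbitrary fan-in of 256: compare the fraction of activations that are saturated (|tanh(z)| > 0.99) with unit-variance weights versus weights scaled by sqrt(1/n_in). The sizes, seed, and the helper `saturated_fraction` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, batch = 256, 128, 1000
x = rng.standard_normal((n_in, batch))          # unit-variance inputs, one column per example

def saturated_fraction(W):
    """Fraction of tanh activations that land in the flat region |tanh(z)| > 0.99."""
    return np.mean(np.abs(np.tanh(W @ x)) > 0.99)

W_unscaled = rng.standard_normal((n_out, n_in))                      # std 1 per weight
W_scaled = rng.standard_normal((n_out, n_in)) * np.sqrt(1.0 / n_in)  # fan-in scaling

print("unscaled:", saturated_fraction(W_unscaled))
print("scaled:  ", saturated_fraction(W_scaled))
```

With std-1 weights the pre-activations have a standard deviation of about sqrt(256) = 16, so most units sit where tanh is flat and passes almost no gradient; the scaled version keeps pre-activations near unit variance.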

Beyond Simple Sequences

RNNs can be extended to handle more complex patterns:

  1. Bidirectional RNNs: Process sequences in both directions
  2. Deep RNNs: Stack multiple RNN layers
  3. Attention Mechanisms: Allow the network to focus on relevant parts of the input sequence
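As an example of the first extension, a bidirectional RNN runs one RNN left to right and a second, separately parameterized RNN right to left, then concatenates their hidden states so that each position sees both past and future context. A minimal sketch; the sizes and the helpers `run_rnn`, `bidirectional`, and `init_params` are illustrative assumptions.

```python
import numpy as np

def run_rnn(xs, W_hh, W_xh, b_h):
    """Return the hidden state after each input, starting from h = 0."""
    h = np.zeros((W_hh.shape[0], 1))
    states = []
    for x in xs:
        h = np.tanh(W_hh @ h + W_xh @ x + b_h)
        states.append(h)
    return states

def bidirectional(xs, fwd_params, bwd_params):
    """Concatenate forward states with backward states realigned to the original order."""
    forward = run_rnn(xs, *fwd_params)
    backward = run_rnn(xs[::-1], *bwd_params)[::-1]  # reverse back so index t matches xs[t]
    return [np.vstack([f, b]) for f, b in zip(forward, backward)]

rng = np.random.default_rng(0)
input_size, hidden_size, T = 3, 4, 6

def init_params():
    return (rng.standard_normal((hidden_size, hidden_size)) * 0.1,
            rng.standard_normal((hidden_size, input_size)) * 0.1,
            np.zeros((hidden_size, 1)))

xs = [rng.standard_normal((input_size, 1)) for _ in range(T)]
states = bidirectional(xs, init_params(), init_params())
print(len(states), states[0].shape)   # 6 (8, 1)
```

Each combined state has twice the hidden size; a deep RNN would feed these states as the inputs of a further layer.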

Conclusion

The effectiveness of RNNs comes from their ability to learn complex patterns through simple, recursive operations. While newer architectures like Transformers have emerged, the fundamental insights from RNNs continue to influence deep learning design.

Understanding their mathematical foundations and implementation details helps us appreciate why they work so well and how to use them effectively in practice.