Using GPT-4 to Understand Code

Introduction

I’ve been using GPT-4 to interpret papers and code, and it’s been extremely helpful. Today I’m reading Toy Models of Superposition by Anthropic and following along in the Colab notebook. Below is a conversation with ChatGPT (with GPT-4 as the backend) where I pasted in the first section of code from the Colab notebook and then asked it questions about the code.

The purpose of this post is to show one way that GPT-4 can be useful for learning. I’d be curious to hear how others are using it too!

Some takeaways:

  • It does a really good job of constructing concrete examples, and it can construct as many different examples as you need. (Like the one about news articles below.)

  • It doesn’t need much context to figure out what’s going on. (All it’s given is this one code snippet. I don’t think it’s aware of the actual paper, since the paper was published after GPT-4’s September 2021 training cutoff. The final paragraph below shows just how much it gleans.)

  • It makes nontrivial inferences. (For example, that the model is an autoencoder trying to reconstruct its own input.)

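  • Its execution of the code is wrong. (In the example batch it produces, the values in the tensor are wrong: it puts the non-zero entries in the last instances, where feature_probability is lowest, instead of the first. The shape and high-level structure are right, though.)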

  • Overall, though, its interpretations seem by and large correct.

Conversation with GPT-4

Me:
Consider the following code:

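from dataclasses import dataclass
from typing import Optional
import time

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import trange

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@dataclass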
class Config:
  n_features: int
  n_hidden: int

  # We optimize n_instances models in a single training loop
  # to let us sweep over sparsity or importance curves 
  # efficiently.

  # We could potentially use torch.vmap instead.
  n_instances: int
 
class Model(nn.Module):
  def __init__(self, 
               config, 
               feature_probability: Optional[torch.Tensor] = None,
               importance: Optional[torch.Tensor] = None,               
               device='cuda'):
    super().__init__()
    self.config = config
    self.W = nn.Parameter(torch.empty((config.n_instances, config.n_features, config.n_hidden), device=device))
    nn.init.xavier_normal_(self.W)
    self.b_final = nn.Parameter(torch.zeros((config.n_instances, config.n_features), device=device))

    if feature_probability is None:
      feature_probability = torch.ones(())
    self.feature_probability = feature_probability.to(device)
    if importance is None:
      importance = torch.ones(())
    self.importance = importance.to(device)

  def forward(self, features):
    # features: [..., instance, n_features]
    # W: [instance, n_features, n_hidden]
    hidden = torch.einsum("...if,ifh->...ih", features, self.W)
    out = torch.einsum("...ih,ifh->...if", hidden, self.W)
    out = out + self.b_final
    out = F.relu(out)
    return out

  def generate_batch(self, n_batch):
    feat = torch.rand((n_batch, self.config.n_instances, self.config.n_features), device=self.W.device)
    batch = torch.where(
        torch.rand((n_batch, self.config.n_instances, self.config.n_features), device=self.W.device) <= self.feature_probability,
        feat,
        torch.zeros((), device=self.W.device),
    )
    return batch

def linear_lr(step, steps):
  return (1 - (step / steps))

def constant_lr(*_):
  return 1.0

def cosine_decay_lr(step, steps):
  return np.cos(0.5 * np.pi * step / (steps - 1))

def optimize(model, 
             render=False, 
             n_batch=1024,
             steps=10_000,
             print_freq=100,
             lr=1e-3,
             lr_scale=constant_lr,
             hooks=[]):
  cfg = model.config

  opt = torch.optim.AdamW(list(model.parameters()), lr=lr)

  start = time.time()
  with trange(steps) as t:
    for step in t:
      step_lr = lr * lr_scale(step, steps)
      for group in opt.param_groups:
        group['lr'] = step_lr
      opt.zero_grad(set_to_none=True)
      batch = model.generate_batch(n_batch)
      out = model(batch)
      error = (model.importance*(batch.abs() - out)**2)
      loss = einops.reduce(error, 'b i f -> i', 'mean').sum()
      loss.backward()
      opt.step()
    
      if hooks:
        hook_data = dict(model=model,
                         step=step, 
                         opt=opt,
                         error=error,
                         loss=loss,
                         lr=step_lr)
        for h in hooks:
          h(hook_data)
      if step % print_freq == 0 or (step + 1 == steps):
        t.set_postfix(
            loss=loss.item() / cfg.n_instances,
            lr=step_lr,
        )


config = Config(
    n_features = 5,
    n_hidden = 2,
    n_instances = 10,
)

model = Model(
    config=config,
    device=DEVICE,
    # Exponential feature importance curve from 1 to 1/100
    importance = (0.9**torch.arange(config.n_features))[None, :],
    # Sweep feature frequency across the instances from 1 (fully dense) to 1/20
    feature_probability = (20 ** -torch.linspace(0, 1, config.n_instances))[:, None]
)

optimize(model)

Explain what it’s doing in simple terms.

GPT:
This code defines and trains a simple neural network model to learn sparse patterns of input data. Here’s a high-level breakdown of what it does:

1. The code defines a Config class to store configuration settings for the neural network, such as the number of input features, the number of hidden neurons, and the number of instances (models) being optimized at once.

2. A Model class is defined, which is a neural network model using PyTorch. This class has a constructor that initializes the model’s parameters and an implementation of the forward pass. The generate_batch method is used to create a batch of input data with varying sparsity levels.

3. Three learning rate scheduling functions are defined: linear_lr, constant_lr, and cosine_decay_lr. These functions will be used to scale the learning rate during training.

4. The optimize function is defined to train the model. This function takes several parameters, such as the model, number of training steps, learning rate, and more. During training, the optimizer minimizes the mean squared error between the input data and the model’s output, while taking the importance of the features into account.

5. A config object is created using the Config class, with specified values for the number of features, hidden neurons, and instances.

6. A model object is created using the Model class, with the specified configuration and additional parameters, such as the importance of features and feature probability.

7. Finally, the optimize function is called with the model object to train the model.

In simpler terms, this code defines and trains a neural network model to learn from sparse input data, where the importance and presence of features vary across instances. The model is trained to minimize the error between its output and the input data, taking the importance of the features into account.
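Item 3 is easy to make concrete: each schedule returns a multiplier that optimize applies to the base lr at every step. Here is a minimal sketch that restates the three functions with Python’s math module in place of NumPy and prints them over a short, arbitrarily chosen run of 5 steps:

```python
import math

# Same formulas as the notebook's schedules, using math instead of numpy.
def linear_lr(step, steps):
    return 1 - (step / steps)

def constant_lr(*_):
    return 1.0

def cosine_decay_lr(step, steps):
    return math.cos(0.5 * math.pi * step / (steps - 1))

steps = 5  # arbitrary, small enough to print every multiplier
for name, fn in [("linear", linear_lr), ("constant", constant_lr), ("cosine", cosine_decay_lr)]:
    print(name, [round(fn(s, steps), 3) for s in range(steps)])
```

Linear decay ends at 1/steps rather than zero, cosine decay reaches (numerically) zero on the final step, and constant stays at 1. Since optimize defaults to lr_scale=constant_lr, the call optimize(model) above trains at a fixed lr of 1e-3.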

Me:
I don’t get what sparse means here

GPT:
In the context of this code, “sparse” refers to the presence or absence of features in the input data. When data is sparse, it means that most of the values are zeros or close to zero, while only a few of the values are non-zero.

In this specific code, the generate_batch method creates input data with varying sparsity levels. The sparsity is controlled by the feature_probability parameter, which represents the probability of a feature being present (non-zero) in the input data. A lower feature_probability value indicates higher sparsity, meaning more zeros in the input data.

For example, if the input data has 5 features, and the feature_probability is 0.2 for an instance, it means there is a 20% chance that each feature will be present (non-zero) in the input data for that instance. In this case, the input data will be sparse, as most of the feature values will be zero.

The purpose of the code is to train the neural network model to learn patterns in the input data while taking the sparsity and importance of features into account. The model is expected to perform well on instances with different levels of sparsity, since it is trained on a variety of input data with varying sparsity levels.
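GPT’s 20% example can be checked numerically. This sketch uses NumPy in place of PyTorch but mirrors the masking in generate_batch (a uniform value per feature, kept only when a second uniform draw is at or below the probability); the sample count and seed are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, p = 100_000, 5, 0.2

# Mirror generate_batch: draw a value for every feature, then keep it only
# if a second uniform draw falls at or below the feature probability.
values = rng.random((n_samples, n_features))
keep = rng.random((n_samples, n_features)) <= p
batch = np.where(keep, values, 0.0)

frac_nonzero = (batch != 0).mean()
avg_active = (batch != 0).sum(axis=1).mean()
empty = (batch == 0).all(axis=1).mean()

print(f"fraction of non-zero entries: {frac_nonzero:.3f}")       # close to 0.2
print(f"average active features per sample: {avg_active:.2f}")   # close to 1.0
print(f"samples with no active feature: {empty:.3f}")            # close to 0.8**5
```

On average one of the five features is active, and about a third of samples (0.8^5 ≈ 0.33) have no active features at all.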

Me:
Show me an example input

GPT:
Sure! Let’s create an example input using the generate_batch method from the code. First, let’s create a model using the given configuration:

import torch

config = Config(
    n_features = 5,
    n_hidden = 2,
    n_instances = 10,
)

model = Model(
    config=config,
    # Exponential feature importance curve from 1 to 1/100
    importance = (0.9**torch.arange(config.n_features))[None, :],
    # Sweep feature frequency across the instances from 1 (fully dense) to 1/20
    feature_probability = (20 ** -torch.linspace(0, 1, config.n_instances))[:, None]
)

Now, let’s generate an example input batch of size 1 using the generate_batch method:

example_input = model.generate_batch(n_batch=1)
print(example_input)

The output might look like this:

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9446, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4761, 0.0000, 0.8479, 0.0000, 0.0000]]])

In this example, the input batch has a shape of (1, 10, 5), representing 1 batch, 10 instances, and 5 features. As you can see, most of the values are zeros, making the input data sparse. The sparsity level varies across instances (rows), as controlled by the feature_probability parameter. For instance, the last row has more non-zero values, while the other rows are mostly zeros.
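This is where GPT’s execution goes wrong. feature_probability is 20 ** -linspace(0, 1, 10), which is 1 for the first instance and falls to 1/20 for the last, so the first row of a real batch is fully dense and the last rows are mostly zeros: the opposite of the tensor above. A NumPy re-implementation of generate_batch (assuming NumPy in place of PyTorch, and a batch of 10,000 so per-instance densities are measurable) shows this:

```python
import numpy as np

rng = np.random.default_rng(0)
n_batch, n_instances, n_features = 10_000, 10, 5

# Same expression as the notebook: shape [n_instances, 1], from 1 down to 1/20.
feature_probability = (20.0 ** -np.linspace(0, 1, n_instances))[:, None]

# Mirror Model.generate_batch: uniform values, kept with probability
# feature_probability (broadcast over the batch and feature dimensions).
feat = rng.random((n_batch, n_instances, n_features))
mask = rng.random((n_batch, n_instances, n_features)) <= feature_probability
batch = np.where(mask, feat, 0.0)

# Fraction of non-zero entries for each instance.
density = (batch != 0).mean(axis=(0, 2))
for i in range(n_instances):
    print(f"instance {i}: p={feature_probability[i, 0]:.3f}, observed density={density[i]:.3f}")
```

Instance 0 comes out with density 1.0, and instance 9 with roughly 0.05.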

Me:
Give me a real world example of what each of the features could represent

GPT:
Let’s consider a real-world example from the field of natural language processing, specifically text classification. Suppose you’re building a model to classify news articles into different categories like sports, politics, technology, etc. The input features represent the presence of specific keywords or phrases in the articles.

Each feature could represent the frequency or presence of a specific keyword in the article. For example:

1. Feature 1: The presence of the keyword “election”.
2. Feature 2: The presence of the keyword “technology”.
3. Feature 3: The presence of the keyword “sports”.
4. Feature 4: The presence of the keyword “economy”.
5. Feature 5: The presence of the keyword “health”.

In this case, sparse input data means that an article only contains a few of these keywords, while most of the keywords are not present. For instance, a sports article might have the keyword “sports” but not the keywords “election” or “technology”. The neural network model would then learn patterns in the presence of these keywords to classify the articles into the appropriate categories.

In the context of this example, different instances could represent different news sources, each with varying levels of sparsity in the use of these keywords. The model would be trained to perform well on sources with different levels of sparsity, as controlled by the feature_probability parameter.
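To make the analogy concrete, here is a toy encoding of an article into GPT’s five keyword features. The keyword list comes from GPT’s answer; the function name and the simple whitespace tokenization are hypothetical and not part of the notebook. Note that the notebook’s features are continuous values in [0, 1) rather than 0/1 flags, so this only illustrates the sparsity pattern:

```python
# Hypothetical keyword list taken from GPT's example; not part of the notebook.
KEYWORDS = ["election", "technology", "sports", "economy", "health"]

def keyword_features(article: str) -> list[float]:
    """Return a 0/1 presence vector with one entry per keyword."""
    # Naive tokenization: split on whitespace and strip common punctuation.
    words = {w.strip(".,!?;:\"'") for w in article.lower().split()}
    return [1.0 if kw in words else 0.0 for kw in KEYWORDS]

article = "The sports desk covered the final, and sports fans celebrated downtown."
print(keyword_features(article))  # [0.0, 0.0, 1.0, 0.0, 0.0]
```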

Me:
What is the model predicting?

GPT:
The model is designed to predict the output values for the given input features while considering the sparsity and importance of those features. The output is a tensor with the same shape as the input features.

In the forward pass of the model, the input features are transformed through a two-layer neural network with a ReLU activation function in the output layer. The loss function used for optimization is the mean squared error between the input data and the model’s output, weighted by the importance of each feature.

In simple terms, the model is trying to learn a representation of the input data that captures the underlying patterns and structure while considering the sparsity and importance of the features. The goal of the model is to minimize the error between the input features and the output it generates, essentially trying to reconstruct the input data as closely as possible. The predicted output can be thought of as the model’s best approximation of the input data after learning the patterns and structure in the data.

It’s important to note that this model is not designed for a specific prediction task, such as classification or regression. Instead, it serves as an example of a neural network learning from sparse input data with varying levels of sparsity and feature importance.
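GPT’s description matches the forward pass: the same matrix W compresses the 5 features into 2 hidden dimensions and then maps them back, followed by a bias and a ReLU, so the output has the input’s shape. Below is a NumPy sketch of forward and of the loss from optimize, assuming NumPy in place of PyTorch, random weights instead of trained ones, and the notebook’s importance = 0.9 ** arange(5) (which runs from 1 down to about 0.66). The last check shows that the instance dimension is a batch of independent models, each using only its own W[i]:

```python
import numpy as np

rng = np.random.default_rng(0)
n_batch, n_instances, n_features, n_hidden = 4, 10, 5, 2

# Untrained stand-ins for the model's parameters (the notebook uses Xavier init).
W = rng.normal(size=(n_instances, n_features, n_hidden))
b_final = np.zeros((n_instances, n_features))
importance = (0.9 ** np.arange(n_features))[None, :]  # shape [1, n_features]

def forward(features):
    # features: [batch, instance, feature] -> hidden: [batch, instance, hidden]
    hidden = np.einsum("bif,ifh->bih", features, W)
    # Map back with the same W (tied weights), add the bias, apply ReLU.
    out = np.einsum("bih,ifh->bif", hidden, W) + b_final
    return np.maximum(out, 0.0)

batch = rng.random((n_batch, n_instances, n_features))
out = forward(batch)

# Importance-weighted squared error, as in optimize():
# mean over batch and features for each instance, then summed over instances.
error = importance * (np.abs(batch) - out) ** 2
loss = error.mean(axis=(0, 2)).sum()

# Instance 0 computed on its own with plain matrix products and only W[0].
out0 = np.maximum(batch[:, 0, :] @ W[0] @ W[0].T + b_final[0], 0.0)

print(out.shape, round(float(loss), 4))
```

Because all five features must pass through only two hidden dimensions, exact reconstruction of dense inputs is impossible in general, which is why sparsity and importance shape what the trained model ends up representing.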