Laying the Foundations for Vision and Multimodal Mechanistic Interpretability & Open Problems

Behold the dogit lens. Patch-level logit attribution is an emergent segmentation map.

Join our Discord here.

This article was written by Sonia Joseph, in collaboration with Neel Nanda, and incubated in Blake Richards’s lab at Mila and in the MATS community. Thank you to the Prisma core contributors, including Praneet Suresh, Rob Graham, and Yash Vadi.

Full acknowledgements of contributors are at the end. I am grateful to my collaborators for their guidance and feedback.

Outline

  • Part One: Introduction and Motivation

  • Part Two: Tutorial Notebooks

  • Part Three: Brief ViT Overview

  • Part Four: Demo of Prisma’s Functionality

    • Key features, including logit attribution, attention head visualization, and activation patching.

    • Preliminary research results obtained using Prisma, including emergent segmentation maps and canonical attention heads.

  • Part Five: FAQ, including Key Differences between Vision and Language Mechanistic Interpretability

  • Part Six: Getting Started with Vision Mechanistic Interpretability

  • Part Seven: How to Get Involved

  • Part Eight: Open Problems in Vision Mechanistic Interpretability

Introducing the Prisma Library for Multimodal Mechanistic Interpretability

I am excited to share with the mechanistic interpretability and alignment communities a project I’ve been working on for the last few months. Prisma is a multimodal mechanistic interpretability library based on TransformerLens, currently supporting vanilla vision transformers (ViTs) and their vision-text counterpart, CLIP.

With recent rapid releases of multimodal models, including Sora, Gemini, and Claude 3, it is crucial that interpretability and safety efforts keep pace. While language mechanistic interpretability already has strong conceptual foundations, many research papers, and a thriving community, research in non-language modalities lags behind. Given that multimodal capabilities will be part of AGI, field-building in mechanistic interpretability for non-language modalities is essential for safety and alignment.

The goal of Prisma is to make research in mechanistic interpretability for multimodal models both easy and fun. We are also building a strong and collaborative open source research community around Prisma. You can join our Discord here.

This post includes a brief overview of the library, fleshes out some concrete problems, and gives steps for people to get started.

Prisma Goals

  1. Build shared infrastructure (Prisma) to make it easy to run standard language mechanistic interpretability techniques on non-language modalities, starting with vision.

  2. Build a shared conceptual foundation for multimodal mechanistic interpretability.

  3. Shape and execute on a research agenda for multimodal mechanistic interpretability.

  4. Build an amazing multimodal mechanistic interpretability subcommunity, inspired by current efforts in language.

    1. Set the cultural norms of this subcommunity to be highly collaborative, curious, inventive, friendly, respectful, prolific, and safety/​alignment-conscious.

    2. Encourage sharing of early/​scrappy research results on Discord/​Less Wrong.

    3. Co-create a web of high-quality research.

Tutorial Notebooks

To get started, you can check out three tutorial notebooks that show how Prisma works.

  1. Main ViT Demo: Overview of the main mechanistic interpretability techniques on a ViT, including direct logit attribution, attention head visualization, and activation patching. The activation patching switches the net’s prediction from tabby cat to Border collie with a minimal ablation.

  2. Emoji Logit Lens: Deeper dive into layer- and patch-level predictions with interactive plots.

  3. Interactive Attention Head Tour: Deeper dive into the various types of attention heads a ViT contains, with interactive JavaScript.

Brief ViT Overview

From “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.”

A vision transformer (ViT) is an architecture designed for image classification tasks, similar to the classic transformer architecture used in language models. A ViT consists of transformer blocks; each block consists of an Attention layer and an MLP layer.

Unlike language models, vision transformers do not have a dictionary-style embedding and unembedding matrix. Instead, images are divided into non-overlapping patches, similar to tokens in language models. These patches are flattened and linearly projected to embeddings via a Conv2D layer, akin to word embeddings in language models. A learnable class token (CLS token) is prepended at the start of the sequence, which accrues global information throughout the network. A learnable position embedding is added to the patch embeddings.
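To make the embedding step concrete, here is a minimal NumPy sketch. It is not Prisma’s code: the function names (`patchify`, `embed`), the 224x224 input, the patch size of 32, and the model width of 64 are all assumptions chosen for illustration. Flattening each patch and multiplying by one weight matrix is equivalent to a Conv2D whose kernel size and stride both equal the patch size.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, c) -> (gh, gw, p, p, c) -> (gh * gw, p * p * c)
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

def embed(image, patch_size, w_proj, cls_token, pos_embed):
    """Project patches to d_model, prepend the CLS token, add position embeddings."""
    patches = patchify(image, patch_size)          # (n_patches, p * p * c)
    tokens = patches @ w_proj                      # (n_patches, d_model)
    tokens = np.concatenate([cls_token[None, :], tokens], axis=0)
    return tokens + pos_embed                      # (n_patches + 1, d_model)

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
image = rng.normal(size=(224, 224, 3))
patch_size, d_model = 32, 64
n_patches = (224 // patch_size) ** 2
w_proj = rng.normal(size=(patch_size * patch_size * 3, d_model)) * 0.02
cls_token = rng.normal(size=d_model)
pos_embed = rng.normal(size=(n_patches + 1, d_model))

tokens = embed(image, patch_size, w_proj, cls_token, pos_embed)
print(tokens.shape)  # (50, 64): 49 patches plus the CLS token
```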

The patch embeddings then pass through the transformer blocks (each block consists of a LayerNorm, an Attention layer, another LayerNorm, and an MLP layer). The output of each block is added back to the previous input. The sum of the block’s output and its previous input is called the residual stream.
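The block structure described above can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the HookedViT implementation: it uses a single attention head, a LayerNorm without learnable scale and bias, and ReLU as a stand-in for GELU. The point is only to show that each sublayer reads from the residual stream and adds its output back.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Simplified LayerNorm: no learnable scale or bias.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(x, w_q, w_k, w_v, w_o):
    # Single-head, bidirectional attention (no causal mask).
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return (softmax(scores) @ v) @ w_o

def mlp(x, w_in, w_out):
    return np.maximum(x @ w_in, 0.0) @ w_out  # ReLU stand-in for GELU

def block(resid, p):
    # Each sublayer reads a normalized copy of the residual stream
    # and writes its output back by addition.
    resid = resid + attention(layer_norm(resid), p["w_q"], p["w_k"], p["w_v"], p["w_o"])
    resid = resid + mlp(layer_norm(resid), p["w_in"], p["w_out"])
    return resid

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
n_tokens, d_model, d_head, d_mlp = 50, 16, 8, 64
params = {
    "w_q": rng.normal(size=(d_model, d_head)) * 0.1,
    "w_k": rng.normal(size=(d_model, d_head)) * 0.1,
    "w_v": rng.normal(size=(d_model, d_head)) * 0.1,
    "w_o": rng.normal(size=(d_head, d_model)) * 0.1,
    "w_in": rng.normal(size=(d_model, d_mlp)) * 0.1,
    "w_out": rng.normal(size=(d_mlp, d_model)) * 0.1,
}
resid_in = rng.normal(size=(n_tokens, d_model))
resid_out = block(resid_in, params)
print(resid_out.shape)  # same shape as the input: (50, 16)
```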

The final layer of this vision transformer is a classification head with 1000 logit values for ImageNet’s 1000 classes. The CLS token is fed into the final layer for 1000-way classification. Adapting TransformerLens, we designed HookedViT to easily capture intermediate activations with custom hook functions, instead of dealing with PyTorch’s normal hook functionality.

Prisma Functionality

We’ll demonstrate the functionality with some preliminary research results. The plots are all interactive but the LW site does not let me render HTML. See the original post for interactive graphs.

Emoji Logit Lens

The emoji logit lens is a convenient way to visualize patch-level predictions for each layer of the net.

We treat every patch like the CLS token, and feed it into the ViT’s 1000-way classification head that’s pre-trained on ImageNet, without fine-tuning. This is equivalent to deleting all layers between the layer of your choice and the output classification head.
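As a rough sketch of the idea (not the notebook’s code), the patch-level logit lens amounts to applying the classification head to every token of the residual stream rather than only to CLS. The function name `patch_logit_lens`, the random weights, and the choice to apply a simplified LayerNorm before the head are assumptions for illustration.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Simplified LayerNorm: no learnable scale or bias.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def patch_logit_lens(resid, w_head, b_head):
    """resid: (n_tokens, d_model) residual stream at a chosen layer; token 0 is CLS.
    Returns the top class index and top logit for each image patch."""
    logits = layer_norm(resid) @ w_head + b_head   # (n_tokens, n_classes)
    patch_logits = logits[1:]                      # drop the CLS row
    return patch_logits.argmax(axis=-1), patch_logits.max(axis=-1)

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
resid = rng.normal(size=(50, 64))        # 49 patches + CLS
w_head = rng.normal(size=(64, 1000))     # 1000 ImageNet classes
b_head = np.zeros(1000)

classes, top_logits = patch_logit_lens(resid, w_head, b_head)
grid = classes.reshape(7, 7)             # back to the 7x7 patch layout
print(grid.shape)  # (7, 7)
```

Each entry of `grid` is a class index for one patch; mapping those indices to emoji gives the kind of plot shown in this section.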

For convenience, we represent the ImageNet prediction of that patch with its corresponding emoji, drawing from our ImageNet-Emoji Dictionary.

Below are the patch-level predictions of the final layer of a ViT for an image of a cat sitting inside a toilet. The yellow means that the logit prediction was high and blue means the logit prediction was low (see the Emoji Logit Lens notebook for more details).

Emergent Segmentation

One of my favorite findings so far is that the patch-level logit lens on the image basically acts as a segmentation map. For the image above, the cat patches get classified as cat, and the toilet patches get classified as a toilet!

This is not an obvious result, as vision transformers are optimized to predict a single class with the CLS token, and not segment the image. The segmentation is an emergent property. See the Emoji Logit Lens Notebook for more details and an interactive visualization.

Similar emergent segmentation capabilities were recently reported by Gandelsman, Efros, and Steinhardt (2024), who found that decomposing CLIP’s image representation across spatial locations allowed obtaining zero-shot semantic segmentation masks that outperformed prior methods. Our results extend this finding to vanilla vision transformers and provide an intuitive visualization using the emoji logit lens.

We can see similar results on other images.

(Note: For visualization purposes, I’ve changed the coloring to be by emoji class instead of logit value like above; see Emoji Logit Lens notebook for details.)

Image and patch-level logit lens of a cheetah chasing a gazelle.
Image and patch-level logit lens of two children with a basketball playing near baby lions.

Interestingly, the net has some biased predictions (“abaya” for the children, perhaps due to their ethnicity), one consequence of only having a 1000-class vocabulary to span concept-space.


Funnily, the net thinks that the center of the green apple (above image, bottom left) is a bagel.


When we do a layer-by-layer logit lens, we see the net’s evolving predictions:

Interestingly, the net picks up on the “animal” at 9_pre (the residual stream before the 9th transformer block) but classifies the cat as a dog. The net only catches onto the cat at 10_pre.

This layer-wise analysis builds upon the work of Gandelsman, Efros, and Steinhardt (2024), who used mean ablations to identify which layers in CLIP have the most significant direct effect on the final representation. Our emoji logit lens provides a complementary view, visualizing how the patch-level predictions evolve across the model’s depth.

Interactive code here.

We can also visualize the evolving per-patch predictions for the above cat/​toilet image for all the layers at once:

Direct Logit Attribution

The library supports direct logit attribution, including at the layer level and the attention-head level.

Below, the net starts making a distinction between tabby/​collie and banana at the eighth layer. See the ViT Prisma Main Demo for the interactive graph.
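For readers new to the technique, here is a minimal sketch of layer-level direct logit attribution. It rests on a simplifying assumption: the final LayerNorm is ignored, so the CLS residual stream is exactly the sum of what each component wrote to it. The function name, class indices, and random weights are hypothetical.

```python
import numpy as np

def logit_diff_attribution(component_outputs, w_head, class_a, class_b):
    """component_outputs: (n_components, d_model), what each component
    (embedding, then each block) wrote to the CLS residual stream.
    Returns each component's contribution to logit[class_a] - logit[class_b]."""
    direction = w_head[:, class_a] - w_head[:, class_b]
    return component_outputs @ direction

# Hypothetical sizes, random weights, and class indices, for illustration only.
rng = np.random.default_rng(0)
n_components, d_model, n_classes = 13, 64, 1000   # embedding + 12 blocks
component_outputs = rng.normal(size=(n_components, d_model))
w_head = rng.normal(size=(d_model, n_classes))
class_a, class_b = 3, 7

final_resid = component_outputs.sum(axis=0)
logits = final_resid @ w_head
contributions = logit_diff_attribution(component_outputs, w_head, class_a, class_b)

# Because the head is linear, the contributions sum to the logit difference.
print(np.isclose(contributions.sum(), logits[class_a] - logits[class_b]))  # True
```

In a real model the final LayerNorm sits between the residual stream and the head, so the decomposition is only exact if the LayerNorm scale is treated as fixed.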

Attention Heads

I wrote an interactive JavaScript visualizer so we can see what each vision attention head is attending to on the image.

The x and y axes of each attention head plot are the flattened image sequence. The sequence has 50 tokens in total (49 image patches plus the CLS token), so each head’s attention pattern is a 50x50 square.
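The 50x50 size follows from simple arithmetic. The sketch below assumes a 224x224 input and a patch size of 32, which is consistent with the 50-token count but is my assumption rather than something stated above; the helper name is hypothetical.

```python
def attention_shape(image_size, patch_size, n_prefix_tokens=1):
    """Return the sequence length and the shape of one head's attention matrix."""
    patches_per_side = image_size // patch_size
    n_tokens = patches_per_side ** 2 + n_prefix_tokens  # patches + CLS
    return n_tokens, (n_tokens, n_tokens)

print(attention_shape(224, 32))  # (50, (50, 50))
print(attention_shape(224, 16))  # (197, (197, 197))
```

Halving the patch size roughly quadruples the sequence length, so the attention matrix grows by about a factor of sixteen.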

Upon initial inspection, the first layer’s attention heads are extremely geometric.

Corner Head, Edges Head, and Modulus Head

We can see attention heads’ scores specializing for specific patterns in the data, including what we call a Corner Head, an Edges Head, and a Modulus Head. This is fascinating because the flattened image does not explicitly contain corner, edge, or row/​column information; detecting these patterns is emergent from training.

These findings echo the recent work of Gandelsman, Efros, and Steinhardt (2024) who identified property-specific attention heads in CLIP that specialize in concepts like colors, locations, and shapes. Our results suggest that such specialization is a more general property of vision transformer architectures, including vanilla models trained solely on image classification, and includes even more basic geometric properties like the coordinates of the image.

Interactive code here.


Video of the Corner Head

Activation Patching

Prisma has the activation patching functionality of TransformerLens.

The Cat-Dog Switch

I found a single attention head (Layer 11, Head 4) wherein patching the CLS token of the z-matrix flips the computation from tabby cat to Border Collie. The CLS token in that z-matrix aggregates patch-level cat ear/​face information from the attention pattern.

Our activation patching results demonstrate that this technique can be used to flip the model’s prediction by targeting specific heads, providing a powerful tool for understanding and manipulating the model’s decision-making process.

This result resonates with Gandelsman, Efros, and Steinhardt (2024), who showed that knowledge of head-specific roles in CLIP can be used to manually intervene in the model’s computation, such as removing heads associated with spurious cues.
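To illustrate the mechanics of patching a single head, here is a deliberately tiny toy example. It is not the notebook’s code and does not use Prisma’s or TransformerLens’ hook API: the "model" is just a sum of per-head outputs at the CLS position followed by a linear head, and all numbers are made up.

```python
import numpy as np

def run(head_outputs, w_head, patch=None):
    """head_outputs: (n_heads, d_model), each head's output at the CLS position.
    patch: optional (head_index, replacement_vector) taken from another run."""
    outs = head_outputs.copy()
    if patch is not None:
        idx, vec = patch
        outs[idx] = vec
    return outs.sum(axis=0) @ w_head

# Toy setup: d_model = 2, class 0 = "cat", class 1 = "dog" (hypothetical).
w_head = np.eye(2)
clean = np.array([[0.2, 0.3],    # head 0: weak, mixed evidence
                  [1.0, 0.0]])   # head 1: strong "cat" evidence
source = np.array([[0.2, 0.3],
                   [0.0, 1.0]])  # head 1 in a run that favors "dog"

baseline = run(clean, w_head)
patched_head1 = run(clean, w_head, patch=(1, source[1]))
patched_head0 = run(clean, w_head, patch=(0, source[0]))

print(baseline.argmax())       # 0: "cat"
print(patched_head1.argmax())  # 1: patching head 1 flips the prediction
print(patched_head0.argmax())  # 0: patching head 0 changes nothing
```

In a real model the replacement happens inside the forward pass via a hook on the chosen activation, but the logic is the same: swap one activation, rerun, and compare the output.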

Interactive code here.

Toy Vision Transformers

We are releasing nine tiny ViTs for testing (equivalent to TransformerLens’ gelu-1l) to better isolate behavior. These tiny ViTs were trained by Yash Vadi and Praneet Suresh.

The repo also contains training code to quickly train custom toy ViTs.

HookedViT

We currently support timm’s vanilla ViTs, TinyCLIP, the video vision transformer, and our own custom tiny transformers. More models will come soon based on demand!

FAQ

Is multimodal mechanistic interpretability really that different from language?

Yes and no. Vision mech interpretability is like language mechanistic interpretability, but in a fun-house mirror. Both architectures are transformers, so many LLM techniques carry over. However, there are a few twists:

  • The typical ViT is not doing unidirectional sequence modeling. ViTs use bidirectional attention and predict a global CLS token, rather than predicting the next token in an autoregressive manner. (Note: There are autoregressive vision transformers with basically the same architecture as language, such as Image GPT and Parti, which do next-token image generation. However, as of February 2024, autoregressive vision transformers are not frequently used in the wild.)

  • Bidirectional attention vs causal attention. Language transformers have causal (unidirectional) attention: there is an upper triangular mask on the attention, so that earlier tokens cannot attend to tokens in the future. However, the classical ViT has bidirectional attention. Thus, the ViT does not have the same concept of “time” as language transformers, and some of the original language mechanistic interpretability techniques break. It can be unclear which direction information is flowing. Induction heads, if they are present in vision, would look different from those in language to account for bidirectional attention.

  • CLS token instead of next-token prediction/autoregressive loss. For ViTs, a learnable CLS token, which is prepended to the input, gets fed into the classification head instead of the final token as in language. The CLS token accrues global information from the other patches through self-attention as all the patches pass through the net.

  • No canonical dictionary matrix. Vision is more ambiguous and lacks the standard dictionary matrix like the 50k one for language. For instance, a yellow patch on a goldfinch might represent “yellow,” “wing,” “goldfinch,” “bird,” or “animal,” depending on the granularity, showing hierarchical ambiguity. An animal might be identified specifically as a “Border collie” or more generally as a “dog.” Beyond hierarchy, ambiguity in vision also stems from cultural interpretations and the imprecision of language. Practically, ImageNet’s 1000 classes serve as a makeshift “dictionary,” but it falls short of fully encompassing visual concepts.

  • Additional hyperparameters. Patch size is a vision-specific hyperparameter, determining the size of the patches into which an image is divided. Using smaller patches typically increases accuracy but also computational load, because the cost of attention scales quadratically with the number of patches.

  • There is a zoo of vision transformers. Similar to language, vision transformers come in many forms. The most relevant are the vanilla ViT; CLIP, which is co-trained with text using contrastive loss; and DINO, which uses unlabeled data. There is also a gallery of loss functions used, including classification loss and masked autoencoder loss. Different losses may lead to different emergent symmetries in the model, although this is an open question. For a review of important ViT architectures, check out this survey.
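The causal versus bidirectional distinction from the list above is easy to see in a small NumPy sketch. The function name and the all-zero scores are illustrative assumptions; the only difference between the two modes is whether the strictly upper triangle of the score matrix is masked before the softmax.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_pattern(scores, causal):
    """scores: (n, n) pre-softmax attention scores; rows are queries."""
    n = scores.shape[-1]
    if causal:
        # Mask the strictly upper triangle so position i cannot see j > i.
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    return softmax(scores)

scores = np.zeros((4, 4))
causal = attention_pattern(scores, causal=True)
bidirectional = attention_pattern(scores, causal=False)

print(causal[0])         # [1. 0. 0. 0.]: the first token only sees itself
print(bidirectional[0])  # [0.25 0.25 0.25 0.25]: every token sees every token
```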

If there is demand, I may write up a post giving a deeper and more theoretical take on the differences between language and non-language mechanistic interpretability.

Why start with vision transformers?

Vision transformers have an extremely similar architecture to language transformers, so many of the existing techniques transfer over cleanly.

Diffusion models are the next obvious frontier, but there will be a larger conceptual leap in designing mechanistic interpretability techniques, largely due to their iterative denoising process. I’d be happy to collaborate on this with anyone who is serious about building strong conceptual foundations here.

Getting Started with Vision Mechanistic Interpretability

  1. It is wise to first have a strong foundation in language mechanistic interpretability, so check out the loads of resources already on this forum. The ARENA curriculum is a good place to start.

  2. Check out these tutorial notebooks:

    1. Main ViT Demo: Includes direct logit attribution, activation patching, and other standard mech interp techniques. The notebook has a section that switches a ViT’s prediction from tabby cat to Border Collie with a minimal ablation.

    2. Emoji Logit Lens

    3. Interactive Attention Head Tour

    4. Optional:

      1. Mindreader. A viewer for maximally activating images on TinyCLIP, which gives a better intuition of how the model hierarchically processes information.

  3. Check out my brief paper list on Vision Transformers and Mechanistic Interpretability for general context.

How to get involved

  1. Use and contribute to the Prisma repo. Check out our open Issues.

    1. Check out our Spring 2024 Roadmap

  2. Community. Join the Prisma Discord.

  3. Collaboration. Work on the Open Problems below.

    1. Post more Open Problems in the comments, or discuss what excites you in particular!

  4. Get mentorship. I would be happy to mentor people on any of the Open Problems below, or a new one that you propose. However, please ensure your proposal is well-thought-out and includes preliminary results on any of the medium to hard problems. Feel free to reach out to me on the LW forum or the Prisma Discord.

  5. Funding for open source alignment and interpretability. As an open-source project, we value the support of our community and sponsors. Open source funding helps us cover expenses and invest in the project’s growth. Feel free to reach out if you’d like to discuss potential funding opportunities!

Open Problems in Vision Mechanistic Interpretability

Here are some Open Problems to get started. If inspired, you are encouraged to post your own in the comments, or comment on the ideas that most grab your attention.

Easy and Exploratory

  1. Explore interpretable neurons in Mindreader.

  2. Play with the attention head JavaScript visualizers to visually inspect attention heads.

    • Do any heads seem to pick up on the same patterns across images?

    • Find 1-3 attention heads that appear to capture interesting features.

  3. Generate maximally activating images for a ViT/​CLIP. The maximally activating images were used to make the Mindreader. I will provide sample code to identify the maximally activating images for each neuron on request.

    • Experiment with different models and datasets for finding maximally activating images.

    • Plot histograms of neuron activations to assess false positives and negatives.

  4. Run the logit lens on various images in the Main Demo notebook.

    • Are certain ImageNet classes identified earlier in the net than others?

  5. Load in a toy ViT into Prisma and examine the attention heads. Do you notice any invariant attention scores/​patterns for certain types of images?

    • Look at the attention heads of the 3-layer toy ViT.

    • Find canonical circuit patterns in the 1-4 layer ViTs. (Note these ViTs are patch size 16 whose attention is currently not feasible to render in the JavaScript attention head viewer.)
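For the maximally activating images problem above, the core step is a per-neuron top-k over a dataset. The sketch below is my own minimal version, not the sample code mentioned in that problem: it assumes you have already collected one activation value per image and neuron (for example, the maximum over patches), and the function name and toy numbers are hypothetical.

```python
import numpy as np

def top_k_images(activations, k):
    """activations: (n_images, n_neurons), one summary value per image and neuron.
    Returns (n_neurons, k) image indices, sorted from highest activation down."""
    order = np.argsort(-activations, axis=0)  # descending along the image axis
    return order[:k].T

# Toy activations for 4 images and 2 neurons.
acts = np.array([[0.1, 2.0],
                 [0.9, 0.0],
                 [0.5, 1.0],
                 [0.7, 3.0]])
print(top_k_images(acts, k=2))
# [[1 3]
#  [3 0]]
```

The same activation matrix can be fed to a histogram per neuron to look at the full distribution rather than only the top few images.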

Expanding Techniques to New Architectures and Datasets

  1. Swap CLIP /​ Vanilla ViTs. Run the techniques above but swap the architecture.

    • Run the Main Demo on TinyCLIP instead of a vanilla ViT. (The vanilla ViT is pretrained only on ImageNet-1k, whereas TinyCLIP is trained on image/text pairs.)

    • For the maximally activating images task (Problem 3), use a vanilla ViT instead of TinyCLIP.

  2. Video Vision Transformer. Run the techniques on a video vision transformer, which you can specify in the HookedViTConfig here (thanks Rob Graham for the idea and adapting the model to Prisma).

    • How do you account for time as a new dimension? How is the model representing time?

  3. Patch-level labels. Use patch-level labels for your dataset.

    • ImageNet is currently only labeled at the class level. Often, this is too coarse-grained. For example, a man eating a burrito, which is labeled “burrito” for ImageNet, has many components: the man, the burrito, his shirt, and the cutlery on the table. Patch-level labels were created by Rob Graham using SAM on ImageNet Images and give a boolean mask for every object in the image. See our Huggingface link for more details on patch-level labels and the SAM pipeline.

    • Generating maximally activating images (Problem 3) is a good candidate as a task for patch-level labels, because now your neurons’ labels will be much more fine-grained than merely using ImageNet class-level labels.
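To use pixel-level object masks with a ViT, they need to be mapped onto the patch grid. The sketch below shows one possible way to do that; it is not the pipeline used to build the patch-level labels mentioned above. The function name and the 50% coverage threshold are assumptions.

```python
import numpy as np

def mask_to_patch_labels(mask, patch_size, min_coverage=0.5):
    """mask: (H, W) boolean object mask.
    Returns an (H // p, W // p) boolean grid marking patches that are at least
    min_coverage covered by the object."""
    h, w = mask.shape
    gh, gw = h // patch_size, w // patch_size
    blocks = mask[:gh * patch_size, :gw * patch_size]
    blocks = blocks.reshape(gh, patch_size, gw, patch_size)
    coverage = blocks.mean(axis=(1, 3))  # fraction of object pixels per patch
    return coverage >= min_coverage

# Toy 64x64 mask: the object fills the top-left patch and a sliver of the next.
mask = np.zeros((64, 64), dtype=bool)
mask[0:32, 0:40] = True
print(mask_to_patch_labels(mask, patch_size=32))
# [[ True False]
#  [False False]]
```

The threshold is a design choice: a lower value labels more boundary patches as belonging to the object, a higher value keeps only patches that are mostly covered.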

Deeper Investigations

  1. Finish off the cat/​dog circuit from the Main Demo Notebook.

    • Run linear probes on the notebook/​scratchpad token to the right of the Border Collie. How much cat vs dog information does the patch contain at each layer of the net?

    • How general is Attention Head 4, Layer 11 (the “Cat-Dog Decider Head”), which appears to be pushing the net’s decision from Border Collie to tabby cat? Does the attention head make the same decision for other images containing both cats and dogs?

    • What is the full circuit for cats and dogs, according to the rigorous definition of a circuit?

  2. Attention patterns vs attention scores.

    • In the Interactive Attention Head Tour, the attention scores (pre-softmax) sometimes look more visually interpretable than the attention patterns (post-softmax). How do we connect our observations about the attention scores to our observations about the attention patterns?

  3. Tuned Lens. Use the Tuned Lens instead of the classical logit lens for the Emoji Logit Lens (train a probe for each block). Does the Tuned Lens improve interpretability?

    • Recreate the “Layer-Level Logit Lens” plots in the Emoji Logit Lens notebook using Tuned Lenses instead of the current vanilla logit lens. Do the results corroborate each other?

  4. Attention Ablations: Change the ViT’s bidirectional attention to unidirectional (like language models). How does this affect segmentation maps and information flow?

  5. Adding registers. Try adding registers as in Darcet et al. (2023) and see what happens to the segmentation map.

  6. Superposition. Find ViT neurons in superposition and disentangle the layer with sparse autoencoders (SAEs).

  7. Circuit Identification. Do full circuit identification for simple naturalistic data like MNIST.

  8. Disentanglement datasets. Run the disentanglement dataset dSprites through the model. Does the internal representation of the net show disentanglement? We have some pre-trained dSprites transformers by Yash Vadi here.

  9. Are there induction heads in vision?

    1. Induction heads emerge with two or more layers in language. Is there an analog in vision (or the emergence of some other useful symmetry)?

    2. The symmetries may depend a lot on the loss function (e.g. masked autoencoder losses may yield different symmetries than classification loss, although this is an open question).

    3. Vision attention is bidirectional, so it’s less obvious what “induction” means here. The canonical definition from language breaks down.

  10. Reverse engineer textual inversion—Add a new, made-up word to the CLIP text encoder and fine-tune with 4-6 corresponding images (e.g. you can finetune the model to label your face with your name). How do the model’s internals change?
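One concrete starting point for the attention scores versus attention patterns question above: the softmax is invariant to adding a constant to every score in a query’s row, so the patterns discard per-row offsets that the raw scores still carry. The toy numbers below are made up to show this.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Two score matrices that differ by a constant offset per query row.
scores = np.array([[2.0, 1.0, 0.0],
                   [0.0, 0.0, 0.0]])
shifted = scores + np.array([[5.0], [-3.0]])

same_pattern = np.allclose(softmax(scores), softmax(shifted))
print(same_pattern)  # True: the patterns match even though the scores differ
```

This means structure that is visible in the scores (for example, a row that is uniformly high or low) can vanish in the patterns, which is one reason the two views may look different.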

Advanced Investigations

  1. Vision training dynamics and phase transitions.

    1. Detect canonical phase transitions in ViT training loss curves (analogous to induction head loss bumps).

    2. Praneet Suresh found that reconstruction loss, and visualizing the reconstructed image throughout training, is a convenient way to detect phase transitions.

  2. Compare the interpretability of CLIP vs a vanilla ViT.

    1. There is an unproven intuition that CLIP is more interpretable than a vanilla ViT, because CLIP has better labels. CLIP is co-optimized with captions, which are more granular labels. For example, “the tabby cat sat on the window” (a CLIP-style label) is more precise than the plain ImageNet class “tabby cat.” The higher-quality labels may result in “better-factored” internal representations.

    2. How interpretable is a ViT in comparison to CLIP? You can explore this question by checking maximally activating images on both models, and running the Logit Lens notebook on TinyCLIP instead of a ViT. Brainstorm your own techniques to compare the internal representations of the models.

  3. Patch information.

    1. How does local information in spatial patches propagate to the global CLS token? Could we get a circuit-level breakdown?

    2. Why do the patches in the upper layers of CLIP store global information? How are they computed from local patches?

  4. Reverse-Engineering Vision Transformer Registers

    1. Vision transformer registers (Darcet et al. (2023)) were a recent finding in the vision community: adding blank tokens dramatically improved the attention maps of a ViT.

    2. On a low-level, mechanistic basis, how does local information from patch tokens propagate to the global tokens, including the CLS token?

    3. Try removing a few register tokens, see what breaks. Are there heads that are specialized to attend to register tokens?

    4. Train an SAE on register tokens and see if you can disentangle what they store. Also run linear probes on register tokens.

  5. Visual reasoning. Use a dataset like CLEVR to see if CLIP does visual reasoning. Does CLIP have visual reasoning “circuits”? Note: CLIP may be bad at CLEVR, which is a complex dataset. Try creating your own simpler visual reasoning tasks (e.g. 2 apples + 2 apples = 4 apples), as a baseline.

    1. Create an open source dataset with very simple reasoning tasks for this purpose. This would be a service to the broader research community!

  6. Replicate the results of Gandelsman et al. Do you notice attention heads specializing for certain semantic groups?

  7. Other architectures

    1. Explore models like Flamingo.

    2. Lay groundwork to analyze diffusion models.

Acknowledgements

Thank you to this most excellent mosaic of communities.

Thank you to Praneet Suresh, Rob Graham, and Yash Vadi, and the other core contributors to the Prisma Repo. Thank you to my PI, Blake Richards, and the rest of our lab at Mila for their support and feedback.

Thank you to Neel Nanda for guidance in bringing mechanistic interpretability to another modality, to Joseph Bloom for your advice on building a repo, to Arthur Conmy for coining the term “dogit lens,” and to the rest of the MATS community for your feedback.

Thank you to the Prisma group at Mila for your feedback, including Santoshi Ravichandran, Ali Kuwajerwala, Mats L. Richter, and Luca Scimeca; members of LiNCLab, including Arna Ghosh and Dan Levenstein; and members of CERC-AAI lab, including Irina Rish and Ethan Caballero. Thank you to Karolis Ramanauskas, Noah MacCallum, Rob Graham, and Romeo Valentin for your feedback on the tutorial notebooks.

Finally, thank you to the South Park Commons community for your support, including Ker Lee Yap, Abhay Kashyap, Jonathan Brebner, Ruchi Sanghvi, and Aditya Agarwal.

This research was generously supported by Blake Richards’s lab, which was funded by the Bank of Montreal; NSERC (Discovery Grant: RGPIN-2020-05105; Discovery Accelerator Supplement: RGPAS-2020-00031; Arthur B. McDonald Fellowship: 566355-2022); CIFAR (Canada AI Chair; Learning in Machines and Brains Fellowship); and a Canada Excellence Research Chair Award to Prof. Irina Rish; and by South Park Commons. This research was enabled in part by support provided by Calcul Québec and the Digital Research Alliance of Canada. We acknowledge the material support of NVIDIA in the form of computational resources.