[Proposal] Method of locating useful subnets in large models

I’ve seen it suggested (e.g., here) that we could tackle the outer alignment problem by using interpretability tools to locate the learned “human values” subnet of powerful, unaligned models. Here I outline a general method for extracting such subnets from a large model.

Suppose we have a large, unaligned model M. We want to extract from M a small subnet that is useful for a certain task T, such as modeling human values or translating languages. My proposal for finding such a subnet is to train a “subnet extraction” (SE) model through reinforcement learning.

We’d give SE access to M’s weights as well as an evaluation dataset for T. Presumably, SE would be a Perceiver or a similar architecture that can handle very large inputs, processing only small parts of M at a time.
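To make “processing only small parts of M at a time” concrete, here is a minimal Python sketch that views M’s weights as a flat list, walks it in fixed-size chunks, and keeps the highest-scoring chunks as the subnet. The names `iter_chunks`, `select_subnet`, and `score_chunk` are hypothetical, and `score_chunk` is only a placeholder for whatever policy SE would actually learn; a real Perceiver-style SE would attend across chunks rather than score each one in isolation.

```python
from typing import Callable, Iterator, Sequence, Set, Tuple


def iter_chunks(weights: Sequence[float],
                chunk_size: int) -> Iterator[Tuple[int, Sequence[float]]]:
    """Yield (start_index, chunk) pairs so only chunk_size weights are viewed at once."""
    for start in range(0, len(weights), chunk_size):
        yield start, weights[start:start + chunk_size]


def select_subnet(weights: Sequence[float], chunk_size: int, k: int,
                  score_chunk: Callable[[Sequence[float]], float]) -> Set[int]:
    """Hypothetical stand-in for SE's policy: score each chunk, then return the
    parameter indices of the k highest-scoring chunks as the extracted subnet."""
    scored = [(score_chunk(chunk), start, len(chunk))
              for start, chunk in iter_chunks(weights, chunk_size)]
    scored.sort(key=lambda item: item[0], reverse=True)
    subnet: Set[int] = set()
    for _, start, length in scored[:k]:
        subnet.update(range(start, start + length))
    return subnet


# Toy usage: score chunks by total absolute weight (an arbitrary heuristic).
weights = [0.1, -2.0, 0.3, 0.0, 1.5, 1.5, -0.2]
print(select_subnet(weights, chunk_size=2, k=2,
                    score_chunk=lambda c: sum(abs(w) for w in c)))
```

The absolute-weight heuristic is only there to exercise the code, not a suggestion for how SE should choose.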

During training, SE would select a small subnet from M. The subnet would then be sandwiched inside an “assister model” (AM) consisting of a pretrained encoder, followed by randomly initialized layers, followed by the extracted subnet, followed by more randomly initialized layers. AM is then finetuned on the dataset for T, and separately on each of a set of distractor datasets, {D_1, …, D_n}. SE receives reward for AM’s post-finetuning performance on T’s dataset minus its average performance on the distractor datasets, and is penalized according to the size of the subnet.
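The sandwich structure of AM can be sketched in plain Python as a list of layers. This sketch assumes, though the post does not say so, that the pretrained encoder and the extracted subnet stay frozen during finetuning, so only the randomly initialized glue layers are trainable; it also uses a single random layer on each side of the subnet for brevity. `Layer`, `random_layer`, `frozen`, `build_assister_model`, and `forward` are hypothetical names.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class Layer:
    weights: List[List[float]]  # out_dim rows, each of length in_dim
    trainable: bool

    def __call__(self, x: List[float]) -> List[float]:
        # Dense layer followed by ReLU; a deliberately minimal layer type.
        return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in self.weights]


def random_layer(in_dim: int, out_dim: int, rng: random.Random) -> Layer:
    # Randomly initialized glue layer: the only parameters finetuning would update.
    rows = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return Layer(rows, trainable=True)


def frozen(layers: List[Layer]) -> List[Layer]:
    # Reuse the same weight lists but mark the layers as not trainable.
    return [Layer(layer.weights, trainable=False) for layer in layers]


def build_assister_model(encoder: List[Layer], subnet: List[Layer],
                         out_dim: int, seed: int = 0) -> List[Layer]:
    """Pretrained encoder -> random layer -> extracted subnet -> random layer."""
    rng = random.Random(seed)
    enc_out = len(encoder[-1].weights)   # width of the encoder's output
    sub_in = len(subnet[0].weights[0])   # width the subnet expects as input
    sub_out = len(subnet[-1].weights)    # width of the subnet's output
    return (frozen(encoder)
            + [random_layer(enc_out, sub_in, rng)]
            + frozen(subnet)
            + [random_layer(sub_out, out_dim, rng)])


def forward(model: List[Layer], x: List[float]) -> List[float]:
    for layer in model:
        x = layer(x)
    return x
```

Freezing the subnet is one way to keep AM’s post-finetuning performance tied to what SE actually extracted, rather than to whatever finetuning could rebuild in its place.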

$$R = \mathrm{finetune}(AM, T) - a \cdot \frac{1}{n} \sum_{i=1}^{n} \mathrm{finetune}(AM, D_i) - b \cdot |\mathrm{subnet}|$$

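where finetune(AM, X) is AM’s validation performance after it has been finetuned on dataset X, |subnet| is the size of the extracted subnet (e.g., its parameter count), and a and b are hyperparameters.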
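As code, the reward is a one-liner. This sketch assumes each finetune(...) term has already been reduced to a single validation score where higher is better (e.g., accuracy); `se_reward` and its argument names are hypothetical.

```python
from statistics import mean
from typing import Sequence


def se_reward(task_score: float, distractor_scores: Sequence[float],
              subnet_size: int, a: float, b: float) -> float:
    """R = finetune(AM, T) - a * mean_i finetune(AM, D_i) - b * |subnet|."""
    if not distractor_scores:
        raise ValueError("at least one distractor dataset is required")
    return task_score - a * mean(distractor_scores) - b * subnet_size


# Example: strong on T, mediocre on distractors, 100k-parameter subnet.
print(se_reward(0.9, [0.4, 0.6], subnet_size=100_000, a=1.0, b=1e-6))
```

With a = 1, a subnet that helps AM equally on T and on the distractors earns nothing beyond the (negative) size penalty, which is the behavior the distractor term is meant to produce.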

The idea is that the subnet M uses to solve T can be easily adapted to solve T in other contexts. However, such subnets may rely on features computed by other parts of M. That’s why I sandwich the subnet inside AM: AM is meant to supply the subnet with generic features, so that SE doesn’t have to extract those generic features from M as well.

I include the distractor datasets so that SE learns to extract subnets specific to the given task/dataset, rather than subnets of M that are simply good at learning any task. I encourage SE to extract smaller subnets because I expect them to be easier to analyze with other interpretability tools, and because I think smaller subnets are less likely to include risky spillover from M (e.g., mesa optimizers). During training, we’d cycle the evaluation task/dataset, with the aim that SE learns to be a general subnet extractor for whatever dataset it’s given.
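One way to cycle the evaluation task is a round-robin schedule over a pool of datasets. The sketch below additionally assumes that, for each evaluation task, the distractors are sampled from the other datasets in the pool; the post doesn’t specify where distractors come from, and `task_schedule` and the dataset names are hypothetical.

```python
import itertools
import random
from typing import Iterator, List, Sequence, Tuple


def task_schedule(pool: Sequence[str], n_distractors: int,
                  seed: int = 0) -> Iterator[Tuple[str, List[str]]]:
    """Yield (evaluation_task, distractor_datasets) pairs, cycling through pool forever."""
    if n_distractors > len(pool) - 1:
        raise ValueError("need more datasets in the pool than distractors per task")
    rng = random.Random(seed)
    for task in itertools.cycle(pool):
        # Distractors never include the task currently being evaluated.
        others = [d for d in pool if d != task]
        yield task, rng.sample(others, n_distractors)


schedule = task_schedule(["translation", "sentiment", "arithmetic"], n_distractors=2)
for _ in range(4):
    print(next(schedule))
```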

When we want to extract the human values subnet, we’d give SE a dataset that we think requires human value modeling to solve. We’d then continue SE’s training process, rewarding SE for the subnets it extracts. Potentially, we could increase b over time to push SE toward the minimum-size subnet that represents human values.
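Increasing b over time could be as simple as a linear ramp. The schedule shape, the name `size_penalty`, and its parameters are all assumptions for illustration; a geometric ramp, or raising b only once SE’s reward plateaus, would also be reasonable choices.

```python
def size_penalty(step: int, b_start: float, b_end: float, ramp_steps: int) -> float:
    """Linearly move b from b_start to b_end over ramp_steps, then hold it at b_end."""
    if ramp_steps <= 0:
        return b_end
    frac = min(step / ramp_steps, 1.0)
    return b_start + frac * (b_end - b_start)


for step in (0, 500, 1000, 2000):
    print(step, size_penalty(step, b_start=0.0, b_end=1.0, ramp_steps=1000))
```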

One potential risk is that a mesa optimizer in M could steer SE’s training toward extracting it, by performing very well on T while deliberately failing on the distractors. To address this, we can compare the subnets extracted for various tasks/datasets to check whether they share weights, and add a term to SE’s reward that encourages diversity among the subnets it extracts for different tasks.
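A minimal sketch of both checks, treating each extracted subnet as a set of parameter indices in M: `shared_weights` lists indices that appear in more than one task’s subnet, and `diversity_bonus` is one minus the mean pairwise Jaccard overlap. Using Jaccard overlap as the diversity measure is an assumption made for this sketch, not something the post specifies, and the function names are hypothetical. The bonus could then be added to R with a hypothetical weight c, i.e. R + c · diversity_bonus.

```python
from collections import Counter
from itertools import combinations
from typing import Dict, Set


def shared_weights(subnets: Dict[str, Set[int]]) -> Set[int]:
    """Parameter indices that appear in the subnets of two or more tasks."""
    counts = Counter(i for subnet in subnets.values() for i in subnet)
    return {i for i, c in counts.items() if c > 1}


def jaccard(a: Set[int], b: Set[int]) -> float:
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def diversity_bonus(subnets: Dict[str, Set[int]]) -> float:
    """1 - mean pairwise Jaccard overlap; 1.0 means no two subnets share weights."""
    pairs = list(combinations(subnets.values(), 2))
    if not pairs:
        return 0.0  # a single subnet gives nothing to compare against
    return 1.0 - sum(jaccard(x, y) for x, y in pairs) / len(pairs)


subnets = {"T1": {0, 1, 2, 3}, "T2": {2, 3, 4, 5}, "T3": {6, 7}}
print(shared_weights(subnets), diversity_bonus(subnets))
```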