OpenAI Solves (Some) Formal Math Olympiad Problems

Epistemic status: I have only skimmed OpenAI’s blogpost and paper, and I do not fully understand the details.

From the blogpost

We built a neural theorem prover for Lean that learned to solve a variety of challenging high-school olympiad problems, including problems from the AMC12 and AIME competitions, as well as two problems adapted from the IMO.
[...]
The prover uses a language model to find proofs of formal statements. Each time we find a new proof, we use it as new training data, which improves the neural network and enables it to iteratively find solutions to harder and harder statements.

From the paper

We explore the use of expert iteration in the context of language modeling applied to formal mathematics. We show that at same compute budget, expert iteration, by which we mean proof search interleaved with learning, dramatically outperforms proof search only. We also observe that when applied to a collection of formal statements of sufficiently varied difficulty, expert iteration is capable of finding and solving a curriculum of increasingly difficult problems, without the need for associated ground-truth proofs. Finally, by applying this expert iteration to a manually curated set of problem statements, we achieve state-of-the-art on the miniF2F benchmark, automatically solving multiple challenging problems drawn from high school olympiads.

Method

  • Uses the Lean formal environment instead of the Metamath used in GPT-f.

  • Uses “decoder-only Transformers similar to GPT-3” with 774M trainable parameters

  • Pre-trained “successively on GPT-3’s postprocessed version of CommonCrawl (for 300B tokens) and an updated version of WebMath (for 72B tokens)”

  • “proof search interleaved with learning”
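To make “proof search interleaved with learning” concrete, here is a toy sketch of the loop. It is not OpenAI's implementation: the integer `skill` is a hypothetical stand-in for the language model, a statement counts as proved when its made-up difficulty is at most `skill`, and "training" just raises `skill` one step past the hardest proof found so far.

```python
def expert_iteration(difficulties, rounds=5, initial_skill=1, learn=True):
    """Toy expert iteration: alternate proof search and learning.

    `skill` is a hypothetical integer stand-in for the language model;
    a statement counts as "proved" when its difficulty is <= skill.
    Returns the cumulative number of solved statements after each round.
    """
    skill = initial_skill
    solved = set()
    history = []
    for _ in range(rounds):
        # Proof search: try every unsolved statement with the current model.
        found = {i for i, d in enumerate(difficulties)
                 if i not in solved and d <= skill}
        solved |= found
        history.append(len(solved))
        if learn and found:
            # Learning: new proofs become training data, so the toy model
            # can now reach one step beyond the hardest proof it has seen.
            skill = max(difficulties[i] for i in solved) + 1
    return history


if __name__ == "__main__":
    statements = [1, 2, 2, 3, 5, 6]  # made-up difficulty levels
    print("with learning:   ", expert_iteration(statements))
    print("search only:     ", expert_iteration(statements, learn=False))
```

In this toy, search without learning never gets past the easiest statement, while the interleaved loop climbs through difficulties 1, 2 and 3 and then stalls because nothing of difficulty 4 bridges the gap to 5. That stall is a cartoon of why the paper asks for statements of “sufficiently varied difficulty”.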

The two IMO-adapted problems

Problem 1: Suppose a, b, c are the sides of a triangle. Prove that a^2(b + c − a) + b^2(c + a − b) + c^2(a + b − c) ≤ 3abc.
Problem 2: For a, b, c reals, prove that (a^2 + ab + b^2)(b^2 + bc + c^2)(c^2 + ca + a^2) ≥ (ab + bc + ca)^3.
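As a sanity check on the two statements (not a proof, and unrelated to how the prover works), the following sketch evaluates both inequalities on random integer inputs using exact integer arithmetic. The function names, the sampling range, and the way triangle sides are generated are all illustrative assumptions.

```python
import random


def triangle_ineq_holds(a, b, c):
    """Problem 1: a^2(b+c-a) + b^2(c+a-b) + c^2(a+b-c) <= 3abc."""
    lhs = a**2 * (b + c - a) + b**2 * (c + a - b) + c**2 * (a + b - c)
    return lhs <= 3 * a * b * c


def product_ineq_holds(a, b, c):
    """Problem 2: (a^2+ab+b^2)(b^2+bc+c^2)(c^2+ca+a^2) >= (ab+bc+ca)^3."""
    lhs = (a*a + a*b + b*b) * (b*b + b*c + c*c) * (c*c + c*a + a*a)
    return lhs >= (a*b + b*c + c*a) ** 3


def is_triangle(a, b, c):
    """True when a, b, c satisfy the strict triangle inequalities."""
    return a + b > c and b + c > a and c + a > b


def spot_check(trials=10_000, seed=0):
    """Return True if no sampled integer triple violates either inequality."""
    rng = random.Random(seed)
    for _ in range(trials):
        a, b, c = (rng.randint(-50, 50) for _ in range(3))
        if not product_ineq_holds(a, b, c):
            return False
        # Shift to positive integers and keep only valid triangles.
        x, y, z = abs(a) + 1, abs(b) + 1, abs(c) + 1
        if is_triangle(x, y, z) and not triangle_ineq_holds(x, y, z):
            return False
    return True


if __name__ == "__main__":
    print(spot_check())
```

Integers are used so that the equality case a = b = c is compared exactly instead of being blurred by floating-point rounding.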

Both solutions use “nlinarith” applied to the right arguments. As far as I understand, nlinarith is a mathlib tactic for nonlinear arithmetic goals: it extends linarith by adding extra facts to the solver's context, namely products of pairs of the hypotheses and of the terms passed as arguments, before running the linear arithmetic routine. (source)
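For readers who have not seen the tactic, here is a minimal illustration of passing a hint to nlinarith. It is a standard textbook-style inequality, not one of the problems from the paper, and it assumes Lean 3 with mathlib (the setup the paper reports using).

```lean
import data.real.basic
import tactic

-- Illustration only, not from the paper.
-- The hint `sq_nonneg (x - y)` supplies the fact 0 ≤ (x - y)^2,
-- which expands to x^2 - 2*x*y + y^2 and closes the goal.
example (x y : ℝ) : 2 * x * y ≤ x^2 + y^2 :=
by nlinarith [sq_nonneg (x - y)]
```

The interesting part of the model's proofs is the same as here: choosing which nonnegative terms to hand to the tactic.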

The right arguments for the first problem are said in the blogpost to come (informally) from Schur’s inequality, which gives

nlinarith [sq_nonneg (b - a), sq_nonneg (c - b), sq_nonneg (c - a)]
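One way to see why squared differences are plausible hints for Problem 1 is the following identity. This is a hand derivation offered as intuition, not a trace of what nlinarith does internally. The first line is the degree-one Schur expression; the second rewrites it using the triangle factors.

```latex
\begin{aligned}
3abc - \bigl[a^2(b+c-a) + b^2(c+a-b) + c^2(a+b-c)\bigr]
  &= a(a-b)(a-c) + b(b-a)(b-c) + c(c-a)(c-b) \\
  &= \tfrac{1}{2}\bigl[(b+c-a)(b-c)^2 + (c+a-b)(c-a)^2 + (a+b-c)(a-b)^2\bigr] \ge 0 .
\end{aligned}
```

Each term in the last line is a triangle-inequality factor, which is positive for the sides of a triangle, times one of the squares that appear as sq_nonneg hints, so the whole expression is nonnegative.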

The second problem is solved by applying Cauchy-Schwarz multiple times, then using an inequality the model “invented”, and the proof ends with the same nlinarith expression as above.