Resampling Conserves Redundancy & Mediation (Approximately) Under the Jensen-Shannon Divergence
Around two months ago, John and I published Resampling Conserves Redundancy (Approximately). Fortunately, about two weeks ago, Jeremy Gillen and Alfred Harwood showed us that we were wrong.
This proof achieves, using the Jensen-Shannon divergence (“JS”), what the previous one failed to show using KL divergence (“DKL”). In fact, while the previous attempt tried to show only that redundancy is conserved (in terms of DKL) upon resampling latents, this proof shows that the redundancy and mediation conditions are conserved (in terms of JS).
Why Jensen-Shannon?
In just about all of our previous work, we have used DKL as our factorization error. (The error meant to capture the extent to which a given distribution fails to factor according to some graphical structure.) In this post I use the Jensen-Shannon divergence.
$D_{KL}(U\|V) := \mathbb{E}_U\left[\ln\frac{U}{V}\right]$

$JS(U\|V) := \frac{1}{2}D_{KL}\left(U\,\middle\|\,\frac{U+V}{2}\right) + \frac{1}{2}D_{KL}\left(V\,\middle\|\,\frac{U+V}{2}\right)$
The KL divergence is a pretty fundamental quantity in information theory, and is used all over the place. (JS is usually defined in terms of DKL, as above.) We have pretty strong intuitions about what DKL means and it has lots of nice properties which I won’t go into detail about, but we have considered it a strong default when trying to quantify the extent to which two distributions differ.
The JS divergence looks somewhat ad-hoc by comparison. It also has some nice mathematical properties (its square root is a metric, a feature sorely lacking from DKL) and there is some reason to like it intuitively: JS(U||V) is equivalent to the mutual information between X, a variable randomly sampled from one of the distributions, and Z, an indicator which determines the distribution X gets sampled from. So in this sense it captures the extent to which a sample distinguishes between the two distributions.
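This mutual-information reading is easy to check numerically. A minimal Python sketch (using the natural log, matching the DKL definition above): for a uniform indicator Z, I(X;Z) reduces to H((U+V)/2) − (H(U)+H(V))/2, which should equal JS(U||V) exactly.

```python
import numpy as np

rng = np.random.default_rng(0)

def kl(u, v):
    # KL divergence in nats between discrete distributions u and v
    mask = u > 0
    return float(np.sum(u[mask] * np.log(u[mask] / v[mask])))

def js(u, v):
    # Jensen-Shannon divergence via its definition in terms of KL
    m = (u + v) / 2
    return 0.5 * kl(u, m) + 0.5 * kl(v, m)

def entropy(p):
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

# Two arbitrary distributions over 5 outcomes
u = rng.random(5); u /= u.sum()
v = rng.random(5); v /= v.sum()

# I(X;Z) where Z ~ Bernoulli(1/2) selects which distribution X is drawn from:
# I(X;Z) = H(X) - H(X|Z) = H((U+V)/2) - (H(U) + H(V)) / 2
mi = entropy((u + v) / 2) - (entropy(u) + entropy(v)) / 2

assert abs(js(u, v) - mi) < 1e-12
```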
Ultimately, though, we want a more solid justification for our choice of error function going forward.
This proof works, but it uses JS rather than DKL. Is that a problem? Can/Should we switch everything over to JS? We aren’t sure. Some of our focus for immediate next steps is going to be on how to better determine the “right” error function for comparing distributions for the purpose of working with (natural) latents.
And now, for the proof:
Definitions
Let P be any distribution over X and Λ.
I will omit the subscripts if the distribution at hand is the full joint distribution with all variables unbound, i.e. $P_{X,\Lambda}$ is the same as $P$. When variables are bound, they will be written as lower case in the subscript. When this is still ambiguous, the full bracket notation will be used.
First, define auxiliary distributions Q, S, R, and M:
$Q := P_X P_{\Lambda|X_1}, \quad S := P_X P_{\Lambda|X_2}, \quad R := P_X Q_{\Lambda|X_2} = P_X \sum_{X_1}\left[P_{X_1|X_2} P_{\Lambda|X_1}\right], \quad M := P_\Lambda P_{X_1|\Lambda} P_{X_2|\Lambda}$
Q, S, and M each perfectly satisfy one of the (stochastic) Natural Latent conditions: Q and S each satisfy one of the redundancy conditions ($X_2 \to X_1 \to \Lambda$ and $X_1 \to X_2 \to \Lambda$, respectively), and M satisfies the mediation condition ($X_1 \leftarrow \Lambda \to X_2$).
R represents the distribution when both of the redundancy factorizations are applied in series to P.
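As a sanity check on these definitions, here is a Python sketch with all three variables binary (the array layout `P[x1, x2, lam]` and all variable names are my own, not from the post). It constructs the four auxiliary distributions, confirms the two expressions for R agree, and confirms Q exactly satisfies its redundancy condition:

```python
import numpy as np

rng = np.random.default_rng(1)

# Random joint P[x1, x2, lam] over three binary variables
P = rng.random((2, 2, 2)); P /= P.sum()

PX = P.sum(axis=2)                                    # P[x1, x2]
lam_x1 = P.sum(axis=1) / P.sum(axis=(1, 2))[:, None]  # P[lam | x1]
lam_x2 = P.sum(axis=0) / P.sum(axis=(0, 2))[:, None]  # P[lam | x2]
plam = P.sum(axis=(0, 1))                             # P[lam]

Q = PX[:, :, None] * lam_x1[:, None, :]
S = PX[:, :, None] * lam_x2[None, :, :]
M = plam[None, None, :] * (P.sum(axis=1) / plam)[:, None, :] \
                        * (P.sum(axis=0) / plam)[None, :, :]

# R applies both redundancy factorizations in series: R = P_X Q[lam | x2]
Q_lam_x2 = Q.sum(axis=0) / Q.sum(axis=(0, 2))[:, None]
R = PX[:, :, None] * Q_lam_x2[None, :, :]

# Same R via the explicit sum over x1
x1_x2 = PX / PX.sum(axis=0, keepdims=True)            # P[x1 | x2]
R2 = PX[:, :, None] * np.einsum('ab,ag->bg', x1_x2, lam_x1)[None, :, :]
assert np.allclose(R, R2)

# Q exactly satisfies the redundancy condition X2 -> X1 -> Lam:
# Q[lam | x1, x2] does not depend on x2
Q_cond = Q / Q.sum(axis=2, keepdims=True)
assert np.allclose(Q_cond[:, 0, :], Q_cond[:, 1, :])

# All four auxiliaries are normalized distributions
for D in (Q, S, R, M):
    assert np.isclose(D.sum(), 1.0)
```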
Let $\Gamma$ be a latent variable defined by $P[\Gamma=\gamma|X] := P[\Lambda=\gamma|X_1] = P[\Gamma=\gamma|X_1]$, with $P^\Gamma := P_{X,\Lambda}\, P^\Gamma_{\Gamma|X}$
Now, define the auxiliary distributions QΓ, SΓ, and MΓ, similarly as above, and show some useful relationships to P, Q, S, R, and M:
$Q^\Gamma_{X,\gamma} := P_X P^\Gamma_{\gamma|X_1} = P_X Q[\Lambda=\gamma|X_1] = Q[X,\Lambda=\gamma]$

$S^\Gamma_{X,\gamma} := P_X P^\Gamma_{\gamma|X_2} = P_X \sum_{X_1}\left(P_{X_1|X_2} P_{\gamma|X_1}\right) = R[X,\Lambda=\gamma]$

$M^\Gamma_{X,\gamma} := P^\Gamma_\gamma P^\Gamma_{X_1|\gamma} P^\Gamma_{X_2|\gamma} = P[\Lambda=\gamma]\, P[X_1|\Lambda=\gamma]\, R[X_2|\Lambda=\gamma]$

$P^\Gamma_{X,\gamma} = P_X P_{\gamma|X} = Q[X,\Lambda=\gamma]$

$P^\Gamma_\gamma = Q[\Lambda=\gamma] = P[\Lambda=\gamma] = P^\Gamma[\Lambda=\gamma]$

$P^\Gamma_{X_1|\gamma} = P[X_1|\Lambda=\gamma] = Q[X_1|\Lambda=\gamma]$

$P^\Gamma_{X_2|\gamma} = R[X_2,\Lambda=\gamma]/P^\Gamma_\gamma = R[X_2|\Lambda=\gamma]$
Next, the error metric and the errors of interest:
Jensen-Shannon Divergence, and Jensen-Shannon Distance (a true metric):
$JS(U\|V) := \frac{1}{2}D_{KL}\left(U\,\middle\|\,\frac{U+V}{2}\right) + \frac{1}{2}D_{KL}\left(V\,\middle\|\,\frac{U+V}{2}\right)$

$\delta(U,V) := \sqrt{JS(U\|V)} = \delta(V,U)$

$\epsilon_1 := JS(P\|Q), \quad \epsilon_2 := JS(P\|S), \quad \epsilon_{med} := JS(P\|M)$

$\epsilon^\Gamma_1 := JS(P^\Gamma\|Q^\Gamma), \quad \epsilon^\Gamma_2 := JS(P^\Gamma\|S^\Gamma) = JS(Q\|R), \quad \epsilon^\Gamma_{med} := JS(P^\Gamma\|M^\Gamma) = JS(Q\|M^\Gamma)$
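The proofs below lean repeatedly on $\delta$ being a genuine metric, in particular on its triangle inequality. A quick numerical spot-check of that property on random discrete distributions (a Python sketch, not a proof; names are mine):

```python
import numpy as np

rng = np.random.default_rng(4)

def kl(u, v):
    mask = u > 0
    return float(np.sum(u[mask] * np.log(u[mask] / v[mask])))

def js(u, v):
    m = (u + v) / 2
    return 0.5 * kl(u, m) + 0.5 * kl(v, m)

def delta(u, v):
    # Jensen-Shannon distance: the square root of the JS divergence
    return np.sqrt(js(u, v))

# Triangle inequality for delta on random triples of distributions
for _ in range(1000):
    u, v, w = rng.random((3, 6))
    u, v, w = u / u.sum(), v / v.sum(), w / w.sum()
    assert delta(u, w) <= delta(u, v) + delta(v, w) + 1e-12
```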
Theorem
Finally, the theorem:
For any distribution P over (X, Λ), the latent $\Gamma \sim P[\Lambda|X_i]$ has redundancy error of zero on one of its factorizations, while the other factorization errors are bounded by a small factor of the errors induced by Λ. More formally:

$\forall P[X,\Lambda]$, the latent $\Gamma$ defined by $P[\Gamma=\gamma|X] := P[\Lambda=\gamma|X_1]$ has bounded factorization errors $\epsilon^\Gamma_1 = 0$ and $\max(\epsilon^\Gamma_2, \epsilon^\Gamma_{med}) \le 5(\epsilon_1 + \epsilon_2 + \epsilon_{med})$.

In fact, that is a simpler but looser bound than the one proven below, which achieves the more bespoke bounds: $\epsilon^\Gamma_1 = 0$, $\epsilon^\Gamma_2 \le (2\sqrt{\epsilon_1} + \sqrt{\epsilon_2})^2$, and $\epsilon^\Gamma_{med} \le (2\sqrt{\epsilon_1} + \sqrt{\epsilon_{med}})^2$.
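Before the proof, here is a numerical spot-check of the theorem on random joints over three binary variables (a Python sketch; the helper `naturality_errors` and the array layout `P[x1, x2, lam]` are my own constructions). It builds Γ by resampling from P[Λ|X1], whose joint with X is exactly Q, and checks all three bounds:

```python
import numpy as np

rng = np.random.default_rng(2)

def kl(u, v):
    mask = u > 0
    return float(np.sum(u[mask] * np.log(u[mask] / v[mask])))

def js(u, v):
    m = (u + v) / 2
    return 0.5 * kl(u, m) + 0.5 * kl(v, m)

def naturality_errors(J):
    """Return (e1, e2, e_med) for the joint J[x1, x2, lam]."""
    JX = J.sum(axis=2)
    lam_x1 = J.sum(axis=1) / J.sum(axis=(1, 2))[:, None]   # J[lam | x1]
    lam_x2 = J.sum(axis=0) / J.sum(axis=(0, 2))[:, None]   # J[lam | x2]
    plam = J.sum(axis=(0, 1))
    Q = JX[:, :, None] * lam_x1[:, None, :]
    S = JX[:, :, None] * lam_x2[None, :, :]
    M = plam[None, None, :] * (J.sum(axis=1) / plam)[:, None, :] \
                            * (J.sum(axis=0) / plam)[None, :, :]
    return (js(J.ravel(), Q.ravel()),
            js(J.ravel(), S.ravel()),
            js(J.ravel(), M.ravel()))

for _ in range(100):
    P = rng.random((2, 2, 2)); P /= P.sum()
    e1, e2, emed = naturality_errors(P)

    # Resample: Gamma ~ P[Lam | X1]; the joint over (X, Gamma) is exactly Q
    lam_x1 = P.sum(axis=1) / P.sum(axis=(1, 2))[:, None]
    PG = P.sum(axis=2)[:, :, None] * lam_x1[:, None, :]
    g1, g2, gmed = naturality_errors(PG)

    s = np.sqrt
    assert g1 < 1e-12                                  # exact redundancy
    assert g2 <= (2 * s(e1) + s(e2)) ** 2 + 1e-12      # bespoke bound (2)
    assert gmed <= (2 * s(e1) + s(emed)) ** 2 + 1e-12  # bespoke bound (3)
    assert max(g1, g2, gmed) <= 5 * (e1 + e2 + emed) + 1e-12
```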
Proof
(1) $\epsilon^\Gamma_1 = 0$
Proof of (1)
$JS(P^\Gamma\|Q^\Gamma) = 0$, since $P^\Gamma_{X,\gamma} = Q[X,\Lambda=\gamma] = Q^\Gamma_{X,\gamma}$ and $P^\Gamma_{\Lambda|X} = P_{\Lambda|X}$
■
(2) $\epsilon^\Gamma_2 \le (2\sqrt{\epsilon_1} + \sqrt{\epsilon_2})^2$

Lemma 1: $JS(S\|R) \le \epsilon_1$

$S[\Lambda|X_2] = P[\Lambda|X_2] = \sum_{X_1} P[X_1|X_2]\, P[\Lambda|X]$

$R[\Lambda|X_2] = Q[\Lambda|X_2] = \sum_{X_1} P[X_1|X_2]\, P[\Lambda|X_1]$

$JS(S\|R) = \sum_{X_2} P[X_2]\, JS(S_{\Lambda|X_2}\,\|\,R_{\Lambda|X_2}) \le \sum_X P[X_2]\, P[X_1|X_2]\, JS(P_{\Lambda|X}\,\|\,P_{\Lambda|X_1}) = JS(P\|Q) = \epsilon_1$ [1]
Lemma 2: $\delta(Q,S) \le \sqrt{\epsilon_1} + \sqrt{\epsilon_2}$

Let $d_x := \delta(P_{\Lambda|x_1}, P_{\Lambda|x_2})$, $a_x := \delta(P_{\Lambda|x}, P_{\Lambda|x_1})$, and $b_x := \delta(P_{\Lambda|x}, P_{\Lambda|x_2})$

$\delta(Q,S) = \sqrt{JS(Q\|S)} = \sqrt{\mathbb{E}_{P_X} JS(P_{\Lambda|X_1}\,\|\,P_{\Lambda|X_2})} = \sqrt{\mathbb{E}_{P_X} (d_X)^2}$
$\le \sqrt{\mathbb{E}_{P_X} (a_X + b_X)^2}$ by the triangle inequality of the metric $\delta$
$\le \sqrt{\mathbb{E}_{P_X} (a_X)^2} + \sqrt{\mathbb{E}_{P_X} (b_X)^2}$ via the Minkowski inequality
$= \sqrt{JS(P\|Q)} + \sqrt{JS(P\|S)} = \sqrt{\epsilon_1} + \sqrt{\epsilon_2}$
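A numerical spot-check of this lemma's conclusion on random binary-variable joints (a Python sketch; the array layout and names are mine):

```python
import numpy as np

rng = np.random.default_rng(5)

def kl(u, v):
    mask = u > 0
    return float(np.sum(u[mask] * np.log(u[mask] / v[mask])))

def js(u, v):
    m = (u + v) / 2
    return 0.5 * kl(u, m) + 0.5 * kl(v, m)

# Check delta(Q, S) <= sqrt(eps1) + sqrt(eps2) on random joints P[x1, x2, lam]
for _ in range(200):
    P = rng.random((2, 2, 2)); P /= P.sum()
    PX = P.sum(axis=2)
    lam_x1 = P.sum(axis=1) / P.sum(axis=(1, 2))[:, None]
    lam_x2 = P.sum(axis=0) / P.sum(axis=(0, 2))[:, None]
    Q = PX[:, :, None] * lam_x1[:, None, :]
    S = PX[:, :, None] * lam_x2[None, :, :]
    lhs = np.sqrt(js(Q.ravel(), S.ravel()))
    rhs = np.sqrt(js(P.ravel(), Q.ravel())) + np.sqrt(js(P.ravel(), S.ravel()))
    assert lhs <= rhs + 1e-12
```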
Proof of (2)
$\sqrt{\epsilon^\Gamma_2} = \sqrt{JS(P^\Gamma\|S^\Gamma)} = \sqrt{JS(Q\|R)} = \delta(Q,R)$

$\delta(Q,R) \le \delta(Q,S) + \delta(S,R)$ by the triangle inequality of the metric $\delta$
$\le \delta(Q,S) + \sqrt{\epsilon_1}$ by Lemma 1
$\le 2\sqrt{\epsilon_1} + \sqrt{\epsilon_2}$ by Lemma 2
■
(3) $\epsilon^\Gamma_{med} \le (2\sqrt{\epsilon_1} + \sqrt{\epsilon_{med}})^2$
Proof of (3)
$JS(M\|M^\Gamma) = \sum_\gamma P[\Lambda=\gamma]\, JS(P[X_2|\Lambda=\gamma]\,\|\,R[X_2|\Lambda=\gamma]) = \mathbb{E}_{P_\Lambda} JS(S_{X_2|\Lambda}\,\|\,R_{X_2|\Lambda}) \le JS(S\|R)$ by the Data Processing Inequality

$\sqrt{\epsilon^\Gamma_{med}} = \delta(P^\Gamma, M^\Gamma) = \delta(Q, M^\Gamma)$
$\le \delta(Q,P) + \delta(P,M) + \delta(M,M^\Gamma)$ by the triangle inequality of the metric $\delta$
$= \sqrt{\epsilon_1} + \sqrt{\epsilon_{med}} + \sqrt{JS(M\|M^\Gamma)}$
$\le \sqrt{\epsilon_1} + \sqrt{\epsilon_{med}} + \sqrt{JS(S\|R)}$ by the inequality above
$\le 2\sqrt{\epsilon_1} + \sqrt{\epsilon_{med}}$ by Lemma 1
■
Results
So, as shown above (using the Jensen-Shannon divergence as the error function), resampling any latent variable according to either one of its redundancy diagrams (just swap $\epsilon_1$ and $\epsilon_2$ in the bounds when resampling from $X_2$) produces a new latent variable which satisfies the redundancy and mediation diagrams approximately as well as the original, and satisfies one of the redundancy diagrams perfectly.
The bounds are:
$\epsilon^\Gamma_1 = 0$
$\epsilon^\Gamma_2 \le (2\sqrt{\epsilon_1} + \sqrt{\epsilon_2})^2$
$\epsilon^\Gamma_{med} \le (2\sqrt{\epsilon_1} + \sqrt{\epsilon_{med}})^2$
where the epsilons without superscripts are the factorization errors of the original latent $\Lambda$ (with $X$) under the respective naturality conditions.
Bonus
For $a, b > 0$, $(2\sqrt{a} + \sqrt{b})^2 \le 5(a+b)$ by Cauchy–Schwarz with vectors $[2,1]$ and $[\sqrt{a}, \sqrt{b}]$.

Thus the simpler, though looser, bound: $\max\{\epsilon^\Gamma_1, \epsilon^\Gamma_2, \epsilon^\Gamma_{med}\} \le 5(\epsilon_1 + \epsilon_2 + \epsilon_{med})$
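This scalar inequality is easy to spot-check numerically (a minimal Python sketch):

```python
import numpy as np

rng = np.random.default_rng(3)

# Check (2*sqrt(a) + sqrt(b))^2 <= 5*(a + b) on random positive pairs
ab = rng.random((1000, 2)) + 1e-9
lhs = (2 * np.sqrt(ab[:, 0]) + np.sqrt(ab[:, 1])) ** 2
rhs = 5 * ab.sum(axis=1)
assert np.all(lhs <= rhs + 1e-12)
```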
[Edit: Here is a Colab session where I numerically tested the bounds on a system with 3 binary variables, both with random sampling and with a simple gradient ascent test aiming to break the bound. All numerical checks passed.]
[1] The joint convexity of $JS(U\|V)$, which justifies this inequality, is inherited from the joint convexity of KL divergence.