Right now in your code, you only calculate reconstruction error gradients for the very last step.
```python
# stop with probability 1 - delta; only the final step contributes to the loss
if random.random() > delta:
    loss = loss + (probs * err).sum(dim=1).mean()
    break
```
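To see what this estimates, assume the loop runs for at most $N$ steps. Write $e_t$ for the value of `(probs * err).sum(dim=1).mean()` at step $t$ ($N$ and $e_t$ are notation introduced here, not names from your code). Since the loop continues with probability $\delta$, step $t$ is reached with probability $\delta^{t-1}$. It is the final step with probability $\delta^{t-1}(1-\delta)$, so the expected loss is:

$$
\Pr[\text{reach step } t] = \delta^{t-1}, \qquad
\Pr[\text{stop at step } t] = \delta^{t-1}(1-\delta),
$$
$$
\mathbb{E}[\mathcal{L}] = \sum_{t=1}^{N} \delta^{t-1}(1-\delta)\, e_t .
$$

Each sample, however, only carries the gradient of a single $e_t$.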
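In practice, it is more efficient to compute the reconstruction error at every step and weight it by the probability that this step produces the final image. Given that the step was reached, that probability is 1 - delta, because the loop continues with probability delta. The expected loss is the same as before, but every visited step now contributes gradient signal, which usually gives lower-variance gradients: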
```python
# every visited step contributes, weighted by P(this step is the final one)
loss = loss + (1 - delta) * (probs * err).sum(dim=1).mean()
if random.random() > delta:
    break
```
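The claim can be checked with a small Monte Carlo sketch in plain Python rather than PyTorch. The per-step error values, `DELTA`, and the function names below are hypothetical stand-ins, not taken from your code. Both schemes should give the same mean loss, and the every-step version should show a smaller variance.

```python
import random
import statistics

# Hypothetical per-step reconstruction errors e_1..e_N, standing in for
# (probs * err).sum(dim=1).mean() at each step (illustrative values only).
STEP_ERRORS = [4.0, 2.0, 1.0, 0.5]
DELTA = 0.7  # probability of continuing to the next step


def loss_last_step_only(rng: random.Random) -> float:
    """Original scheme: add the error only on the step where the loop stops.

    If the loop never stops before the step cap, the loss stays 0.
    """
    loss = 0.0
    for err in STEP_ERRORS:
        if rng.random() > DELTA:
            loss += err
            break
    return loss


def loss_every_step(rng: random.Random) -> float:
    """Suggested scheme: add (1 - delta) * error at every visited step."""
    loss = 0.0
    for err in STEP_ERRORS:
        loss += (1 - DELTA) * err
        if rng.random() > DELTA:
            break
    return loss


def exact_expected_loss() -> float:
    """sum_t P(reach step t) * P(stop | reached) * e_t, with t starting at 0 here."""
    return sum(DELTA ** t * (1 - DELTA) * err for t, err in enumerate(STEP_ERRORS))


if __name__ == "__main__":
    rng = random.Random(0)
    n = 100_000
    old = [loss_last_step_only(rng) for _ in range(n)]
    new = [loss_every_step(rng) for _ in range(n)]
    print(f"exact mean: {exact_expected_loss():.4f}")
    print(f"old   mean: {statistics.fmean(old):.4f}  var: {statistics.pvariance(old):.4f}")
    print(f"new   mean: {statistics.fmean(new):.4f}  var: {statistics.pvariance(new):.4f}")
```

With these made-up numbers the exact expected loss is 1.81845 for both schemes. How large the variance gap is in a real model will depend on how the per-step errors behave.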