Firth Bias Reduction in Few-Shot Learning

On the Importance of Firth Bias Reduction in Few-Shot Classification
ICLR 2022 Spotlight
Thomas M. Siebel Center for Computer Science
National Center for Super-computing Applications
University of Illinois Urbana-Champaign
University of Illinois Urbana-Champaign
* Denotes Equal Contribution

The Premise

Every time you train a logistic classifier in a few-shot task and minimize a cross-entropy loss, you're essentially performing Maximum Likelihood Estimation (MLE).

In the world of statistics, MLEs are known to suffer from serious bias issues when only a few samples are provided.

The few shots you use are random samples, so the MLEs are going to be random variables too. There is no easy way around this stochasticity and you would always get some randomness or variance in the resulting MLE.

That being said, it's only reasonable to hope that you'd get the right parameters with MLE, at least on average. Well, that's the issue! Not only can MLEs have a lot of variance, but they can also be severely off even on average!

A conceptual visualization of the classifier parameter estimation's bias problem in few-shot learning.

Care for a simple example to see the severity of this issue?

Here is a toy example showcasing this issue in a few-shot geometric experiment with a fair coin (yes, the exact same problem from your introductory probability course).

You want to recover the coin's head probability from a small number of samples, so you use the MLE. However, you're curious whether you'd even get the right parameter on average. To check that, you simulate some experiments in Python and plot the average MLE. Here's what you'll see:

The Average MLE in a Geometric Experiment

[Figure: average MLE versus sample size in the geometric experiment]

The blue points show the average MLE for various sample sizes. The black double-sided vertical arrow shows the MLE bias away from the true parameter. The red points show a slightly better estimator than MLE.

The MLE Bias vs. the Sample Size in a log-log Scale

[Figure: log-10 of the MLE bias versus log-10 of the number of samples]

The vertical axis shows the log-10 of the MLE bias, and the horizontal axis shows the log-10 of the number of samples. The bias isn't going away as fast as you would hope; it's $O(N^{-1})$, which is far from an exponential drop.

This raises the central question in our paper:
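A simulation of this kind can be sketched in a few lines of Python. This is a minimal illustration rather than the code behind the plots above. It assumes each sample is the number of tosses of a fair coin up to and including the first head, so the MLE from N samples is N divided by the total number of tosses. The function names, trial count, and seed are illustrative choices.

```python
import random


def tosses_until_head(rng, p_head):
    """Sample one geometric outcome: tosses up to and including the first head."""
    count = 1
    while rng.random() >= p_head:  # this toss came up tails, so toss again
        count += 1
    return count


def average_mle(n_samples, p_head=0.5, n_trials=20_000, seed=0):
    """Average the MLE p_hat = N / sum(x_i) over many simulated experiments."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_trials):
        n_tosses = sum(tosses_until_head(rng, p_head) for _ in range(n_samples))
        total += n_samples / n_tosses
    return total / n_trials


sample_sizes = (2, 4, 8, 16, 32)
biases = {n: average_mle(n) - 0.5 for n in sample_sizes}
for n, bias in biases.items():
    # If the bias shrinks like 1/N, then N * bias should stay roughly constant.
    print(f"N={n:2d}  average MLE={0.5 + bias:.4f}  bias={bias:.4f}  N*bias={n * bias:.3f}")
```

Under these assumptions the average MLE should sit above the true value of 0.5 at every sample size, and a roughly constant N*bias column is what an $O(N^{-1})$ bias looks like.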

If MLEs cannot recover the true parameter even on average in such an easy problem, then how can we trust they're best for few-shot logistic classifiers with thousands of dimensions?

This motivates the introduction of the Firth MLE penalty, a glimpse of which was shown in the geometric example plot in red.

Firth Bias Reduction in Few Words

For 1-Layer Logistic and Cosine Classifiers with the Cross-Entropy Loss:

All you need to do is replace

$$\hat{\beta} = \text{argmin}_{\beta} \quad \frac{1}{N}\sum_{i=1}^{N} \bigg[\text{CE}(\mathbf{P}_i, \mathbf{y}_i)\bigg]$$

with

$$\hat{\beta}_{\text{Firth}} = \text{argmin}_{\beta} \quad \frac{1}{N}\sum_{i=1}^{N} \bigg[\text{CE}(\mathbf{P}_i, \mathbf{y}_i) + \lambda \cdot \text{CE}(\mathbf{P}_i,\mathbf{U}) \bigg]$$

where $\mathbf{U}$ is the uniform distribution over the classes, and $\lambda$ is a positive constant. The CE term with the uniform distribution is, up to a constant factor, the negative sum of the predicted log-probability values over all data points and classes.

Our paper provides a theoretical proof of why the added penalty is a simplification of a $\log(\det(F))$ term (thereby encouraging "larger" Fisher information).
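Written out, the uniform cross-entropy term expands as follows. The symbols $C$ (the number of classes) and $\mathbf{P}_{i,c}$ (the predicted probability of class $c$ for sample $i$) are notation assumed here for illustration.

```latex
\text{CE}(\mathbf{P}_i, \mathbf{U})
  = -\sum_{c=1}^{C} \frac{1}{C} \log \mathbf{P}_{i,c}
\qquad\Longrightarrow\qquad
\frac{1}{N}\sum_{i=1}^{N} \text{CE}(\mathbf{P}_i, \mathbf{U})
  = -\frac{1}{NC} \sum_{i=1}^{N} \sum_{c=1}^{C} \log \mathbf{P}_{i,c}
```

In other words, the penalty is the negative mean of all $N \times C$ predicted log-probabilities.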

General Firth Bias Reduction Form:

Add a log-determinant of the Fisher information matrix (FIM) term to your loss minimization problem. That is, replace

$$\hat{\beta}=\text{argmin}_{\beta}\quad\bigg[l(\beta)\bigg]$$

with

$$\hat{\beta}_{\text{Firth}} = \text{argmin}_{\beta} \quad \bigg[l(\beta) + \lambda\cdot \log(\det(F))\bigg]$$

Firth's original work proved that this reduces the bias of the estimated parameters.

Experiments and Results

Logistic Classifiers and Basic Feature Backbones

The following is the effect of Firth bias reduction compared to typical L2 regularization in 16-way few-shot classification tasks using basic feature backbones and 1-layer logistic classifiers. The vertical axis shows the accuracy improvements, and the horizontal axis shows the number of shots.

[Figures: accuracy improvement versus number of shots, 1-layer logistic classifiers]

Here's the same set of results, but with 3-layer logistic classifiers (instead of 1-layer networks).

[Figures: accuracy improvement versus number of shots, 3-layer logistic classifiers]

Cosine Classifiers and S2M2R Feature Backbones

Below is the effect of Firth bias reduction on cosine classifiers and S2M2R features. The horizontal axis is the number of classes, and the vertical axis shows the Firth accuracy improvements.

[Figures: Firth accuracy improvement versus number of classes, cosine classifiers on S2M2R features]

Firth Bias Reduction on the Distribution Calibration Method

The following table compares the recent state-of-the-art few-shot Distribution Calibration (DC) method in cross-domain settings, with and without Firth bias reduction. Each setting was tested with and without data augmentation (the addition of 750 samples), and the maximum accuracy was reported. Note that the confidence intervals are much smaller for the improvement column, thanks to the random-effect matching procedure we used in this study.

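| Way | Shot | mini → CUB: Before | mini → CUB: After | mini → CUB: Improvement | tiered → CUB: Before | tiered → CUB: After | tiered → CUB: Improvement |
|-----|------|--------------------|-------------------|-------------------------|----------------------|---------------------|---------------------------|
| 10  | 1    | 37.14 ± 0.12       | 37.41 ± 0.12      | 0.27 ± 0.03             | 64.36 ± 0.16         | 64.52 ± 0.16        | 0.15 ± 0.03               |
| 10  | 5    | 59.77 ± 0.12       | 60.77 ± 0.12      | 1.00 ± 0.04             | 86.23 ± 0.10         | 86.66 ± 0.09        | 0.43 ± 0.03               |
| 15  | 1    | 30.22 ± 0.09       | 30.37 ± 0.09      | 0.15 ± 0.03             | 57.73 ± 0.13         | 57.73 ± 0.13        | 0.00 ± 0.00               |
| 15  | 5    | 52.73 ± 0.09       | 53.84 ± 0.09      | 1.11 ± 0.03             | 82.16 ± 0.09         | 83.05 ± 0.08        | 0.89 ± 0.03               |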
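One way to see why matching can shrink the interval on the improvement: if both methods are evaluated on the same sampled tasks, the task-to-task variation cancels in the per-task difference. The sketch below illustrates this with synthetic numbers only. It assumes that the matching procedure amounts to such a pairing, and the task-difficulty spread, noise level, and gain are hypothetical values, not results from the paper.

```python
import random
import statistics


def ci_half_width(values, z=1.96):
    """Half-width of a normal-approximation 95% confidence interval for the mean."""
    return z * statistics.stdev(values) / len(values) ** 0.5


rng = random.Random(0)
n_tasks = 2000
true_gain = 0.5  # hypothetical accuracy gain, for illustration only

before, after = [], []
for _ in range(n_tasks):
    # Shared random effect: the difficulty of this particular sampled task.
    task_effect = rng.gauss(60.0, 10.0)
    before.append(task_effect + rng.gauss(0.0, 1.0))
    after.append(task_effect + true_gain + rng.gauss(0.0, 1.0))

# Pairing: subtract the two accuracies measured on the same task.
improvement = [a - b for a, b in zip(after, before)]

print(f"before:      ± {ci_half_width(before):.3f}")
print(f"after:       ± {ci_half_width(after):.3f}")
print(f"improvement: ± {ci_half_width(improvement):.3f}")
```

Because the shared task effect drops out of each difference, the interval on the improvement is driven only by the residual noise, which is why it can be several times narrower than the intervals on the raw accuracies.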

Implementation

Implementing Firth bias reduction for 1-layer logistic and cosine classifiers only takes one or two extra lines of code.

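import torch
import torch.nn as nn

# logits: (N, C) classifier outputs; target: (N,) class indices; lam: the Firth coefficient
ce_loss = nn.CrossEntropyLoss()
ce_term = ce_loss(logits, target)

# This is how you can compute the Firth bias reduction from classifier logits:
# the negative mean log-probability over all samples and classes
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
firth_term = -log_probs.mean()

loss = ce_term + lam * firth_term
loss.backward()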

Alternatively, you can use the label_smoothing keyword argument in nn.CrossEntropyLoss. Remember that this Firth formulation holds only for 1-layer logistic and cosine classifiers. For more complex networks, the FIM's log-determinant must be worked out.
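The connection to label smoothing can be checked numerically. PyTorch's label smoothing with factor `eps` mixes the one-hot target with the uniform distribution, so the smoothed loss should equal `(1 - eps)` times the Firth-penalized loss with `lam = eps / (1 - eps)`. The sketch below verifies this on random tensors; the tensor sizes, seed, and `eps` value are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_samples, n_classes = 8, 5
logits = torch.randn(n_samples, n_classes)
target = torch.randint(0, n_classes, (n_samples,))

eps = 0.2                 # label-smoothing factor
lam = eps / (1.0 - eps)   # matching Firth coefficient

# Explicit form: cross-entropy plus lam times the uniform cross-entropy term.
ce_term = F.cross_entropy(logits, target)
firth_term = -F.log_softmax(logits, dim=-1).mean()
firth_loss = ce_term + lam * firth_term

# Label-smoothing form: the same objective, scaled by (1 - eps).
smoothed_loss = F.cross_entropy(logits, target, label_smoothing=eps)

print(torch.allclose(smoothed_loss, (1.0 - eps) * firth_loss, atol=1e-6))
```

The overall `(1 - eps)` factor rescales the gradients, which matters if you compare learning rates between the two forms. Label smoothing also restricts you to `eps` below 1, which corresponds to any non-negative `lam`.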

As for the $\lambda$ coefficient,

  • Firth's original work set it to a pre-determined constant.

  • Recently, $\log F(m,m)$ models proposed scaling Firth's pre-determined coefficient, making $\lambda$ a hyper-parameter.

  • We followed the common machine learning practice, and validated the $\lambda$ coefficient on the validation split, then evaluated the validated $\lambda$ on the novel set.

  • You don't need much resolution for the validation search; we performed the $\lambda$ search in a log-10 space on a handful of candidates ($\lambda\in \{0.0, 0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0\}$).
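The validation search described above amounts to a small grid search. The sketch below assumes a hypothetical callable `validation_accuracy(lam)` that trains the classifier with coefficient `lam` and returns its accuracy on the validation split. Here it is replaced by a toy stand-in, since the actual training pipeline lives in the repositories.

```python
import math

LAMBDA_CANDIDATES = (0.0, 0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0)


def select_lambda(validation_accuracy, candidates=LAMBDA_CANDIDATES):
    """Return the candidate with the highest validation accuracy.

    Ties go to the earliest (smallest) candidate, since max() keeps the first maximum.
    """
    return max(candidates, key=validation_accuracy)


def toy_validation_accuracy(lam):
    # Hypothetical stand-in for "train with lam, then score on the validation split":
    # a curve that peaks at lam = 0.3 on a log-10 scale.
    return 0.6 - 0.05 * abs(math.log10(lam + 1e-3) - math.log10(0.3 + 1e-3))


best_lam = select_lambda(toy_validation_accuracy)
print(best_lam)
```

The selected value would then be used once on the novel set, which keeps the novel-set accuracy free of any tuning.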

Code

Our implementation is open-source and available at https://github.com/ehsansaleh/firth_bias_reduction.

Due to the volume of experimental settings in our paper, we broke down the code into three sub-modules:

  • The code_firth repository corresponds to the Firth bias reduction experiments using standard ResNet architectures and logistic classifiers (e.g., Figures 2 and 3 in the main paper).

  • The code_s2m2rf repository corresponds to the experiments with cosine classifiers on WideResNet-28 feature stacks trained by the S2M2R method.

  • The code_dcf repository contains our GPU implementation of the Distribution Calibration (DC) method and the relevant Firth bias reduction improvements.

All of them are standalone repositories with

  • detailed documentation in their corresponding readme files, and

  • helper scripts for automatically downloading and extracting the features, datasets, and backbone parameters from the external sources (such as Google Drive).

You can clone all three modules with the following command:

git clone --recursive https://github.com/ehsansaleh/firth_bias_reduction.git

Data

We have published all the data, pre-computed features, trained backbone parameters, and other auxiliary files (complementing the open-source code) in two redundant external sources: the Illinois Data Bank and Google Drive.

Illinois Data Bank

We have included our pre-computed features and trained backbones as tarball archives in our Illinois Data Bank repository at https://doi.org/10.13012/B2IDB-1016367_V1, with brief instructions for manually downloading and placing the data.

Google Drive

References

Here is the bibtex citation entry for our work:

@inproceedings{ghaffari2022fslfirth,
    title={On the Importance of Firth Bias Reduction In Few-Shot Classification},
    author={Saba Ghaffari and Ehsan Saleh and David Forsyth and Yu-Xiong Wang},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=DNRADop4ksB}
}