Every time you train a logistic classifier in a few-shot task and minimize a cross-entropy loss, you're essentially performing Maximum Likelihood Estimation (MLE).
In the world of statistics, MLEs are known to suffer from serious bias issues when only a few samples are provided.
The few shots you use are random samples, so the MLEs are random variables too. There is no easy way around this stochasticity, and you will always get some randomness (variance) in the resulting MLE.
That being said, it's only reasonable to hope that MLE would recover the right parameters, at least on average. Well, that's the issue! Not only can MLEs have a lot of variance, they can also be severely off even on average!
Care for a simple example to see the severity of this issue?
Here is a simple toy example showcasing this issue in a few-shot geometric experiment with a fair coin (yes, the same exact problem from your introductory probability course).
You want to recover the coin's head probability from a small number of samples, so you use the MLE. However, you're curious whether you'd even get the right parameter on average. To check that, you simulate some experiments in Python and plot the average MLE. Here's what you'll see:
The blue points show the average MLE for various sample sizes. The black double-sided vertical arrow shows the MLE bias away from the true parameter. The red points show a slightly better estimator than MLE.
The vertical axis shows the log-10 of the MLE bias, and the horizontal axis shows the log-10 of the number of samples. The bias isn't going away as fast as you would hope; it shrinks only polynomially with the number of samples, which is far from an exponential drop.
This raises the central question of our paper:
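A minimal sketch of such a simulation is below. It assumes NumPy and assumes each sample is the number of flips up to and including the first head, so the MLE of the head probability is $\hat{p} = N / \sum_i X_i$. The function name and trial counts are illustrative choices, not the code behind the figure:

```python
import numpy as np

def average_geometric_mle(p_true, n_samples, n_trials=100_000, seed=0):
    """Monte Carlo estimate of E[p_hat] for the geometric MLE p_hat = N / sum(X)."""
    rng = np.random.default_rng(seed)
    # Each row is one few-shot experiment: n_samples draws of "flips until the first head"
    flips = rng.geometric(p_true, size=(n_trials, n_samples))
    mle = n_samples / flips.sum(axis=1)
    return mle.mean()

p_true = 0.5  # fair coin
for n in [1, 2, 4, 8, 16, 32]:
    bias = average_geometric_mle(p_true, n) - p_true
    print(f"N={n:3d}  average MLE bias = {bias:.4f}")
```

For a fair coin and $N = 1$, the average MLE is $E[1/X] = \ln 2 \approx 0.693$, so the estimator is biased upward by about $0.19$ even though every single sample is drawn from the correct model.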
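All you need to do is replace

$$\hat{\beta} = \arg\min_{\beta}\ \frac{1}{N}\sum_{i=1}^{N} \mathrm{CE}(\mathbf{y}_i, \mathbf{P}_i)$$

with

$$\hat{\beta} = \arg\min_{\beta}\ \frac{1}{N}\sum_{i=1}^{N} \Big[\mathrm{CE}(\mathbf{y}_i, \mathbf{P}_i) + \lambda \cdot \mathrm{CE}(\mathbf{U}, \mathbf{P}_i)\Big]$$

where $\beta$ denotes the classifier parameters, $\mathbf{P}_i$ is the predicted class distribution for the $i$-th data point, $\mathbf{y}_i$ is its one-hot label, $\mathbf{U}$ is the uniform distribution over the classes, and $\lambda$ is a positive constant. The CE term with the uniform distribution is basically the (negative) sum of the prediction log-probability values over all data points and classes, up to a constant scaling.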
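A quick numerical check of that claim is sketched below, assuming PyTorch 1.10 or later (where `F.cross_entropy` accepts class-probability targets); `uniform_ce_term` is a hypothetical helper name and the logits are random placeholders:

```python
import torch
import torch.nn.functional as F

def uniform_ce_term(logits):
    """Mean over data points of CE(U, P_i), with U the uniform class distribution."""
    num_classes = logits.shape[-1]
    uniform = torch.full_like(logits, 1.0 / num_classes)
    # Probability targets: per-sample loss is -sum_c U_c * log P_ic, averaged over the batch
    return F.cross_entropy(logits, uniform)

torch.manual_seed(0)
logits = torch.randn(6, 4)  # 6 data points, 4 classes (placeholder values)
log_probs = torch.log_softmax(logits, dim=-1)
# Same quantity written as the negated mean log-probability over points and classes
print(torch.allclose(uniform_ce_term(logits), -log_probs.mean()))  # True
```

Because the $1/C$ weights and the $1/N$ average combine into a plain mean over all $N \times C$ entries, the penalty can be computed directly from the log-probability matrix without building the uniform target explicitly.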
Our paper provides a theoretical proof of why the added penalty is a simplification of the log-determinant of the Fisher information matrix (thereby encouraging a "larger" Fisher information).
General Firth Bias Reduction Form:
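Add a log-det of FIM term to your loss minimization problem. That is, replace

$$\hat{\beta} = \arg\min_{\beta}\ \ell(\beta)$$

with

$$\hat{\beta} = \arg\min_{\beta}\ \Big[\ell(\beta) - \lambda \cdot \log\det F(\beta)\Big]$$

where $\ell(\beta)$ is the negative log-likelihood (e.g., the cross-entropy loss), $F(\beta)$ is the Fisher information matrix (FIM), and $\lambda$ is a positive coefficient.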
Firth's original work proved that this penalty reduces the bias of the estimated parameters.
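As a concrete instance of the general form, here is a minimal sketch for classical binary logistic regression, where the FIM is $F(\beta) = X^\top W X$ with $W = \mathrm{diag}\big(p_i(1-p_i)\big)$ and Firth's original coefficient is $\lambda = 1/2$. This is not the multi-class few-shot setting of the paper; the helper names, the toy data, and the Adam-based fitting loop are our own illustrative choices:

```python
import torch
import torch.nn.functional as F

def firth_logistic_loss(beta, X, y):
    """Binary logistic NLL minus 0.5 * log det of the Fisher information (Firth's penalty)."""
    logits = X @ beta
    # Summed binary cross-entropy = negative log-likelihood of the Bernoulli model
    nll = F.binary_cross_entropy_with_logits(logits, y, reduction="sum")
    p = torch.sigmoid(logits)
    w = p * (1 - p)                   # per-sample Bernoulli variance
    fisher = X.T @ (w[:, None] * X)   # F(beta) = X^T W X
    _, logdet = torch.linalg.slogdet(fisher)
    return nll - 0.5 * logdet

def fit_firth_logistic(X, y, steps=2000, lr=0.05):
    """Minimize the Firth-penalized loss with plain gradient-based optimization."""
    beta = torch.zeros(X.shape[1], dtype=X.dtype, requires_grad=True)
    opt = torch.optim.Adam([beta], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = firth_logistic_loss(beta, X, y)
        loss.backward()
        opt.step()
    return beta.detach()

# Toy, perfectly separable data: an intercept column plus one feature
x = torch.tensor([-2.0, -1.0, 1.0, 2.0], dtype=torch.float64)
X = torch.stack([torch.ones_like(x), x], dim=1)
y = torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float64)
beta_firth = fit_firth_logistic(X, y)
print(beta_firth)
```

Because the toy data are perfectly separable, the plain MLE would push the slope toward infinity; the log-det penalty keeps the estimate finite, a well-known side benefit of Firth's method.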
The following shows the effect of Firth bias reduction compared to typical L2 regularization in 16-way few-shot classification tasks using basic feature backbones and 1-layer logistic classifiers. The vertical axis shows the accuracy improvements, and the horizontal axis shows the number of shots.
Below is the effect of Firth bias reduction on cosine classifiers and S2M2R features. The horizontal axis is the number of classes, and the vertical axis shows the Firth accuracy improvements.
The following shows the recent state-of-the-art few-shot Distribution Calibration (DC) method in cross-domain settings, with and without Firth bias reduction. Each setting was tested with and without data augmentation (the addition of 750 samples), and the maximum accuracy was reported. Note that the confidence intervals are much smaller for the improvement column, thanks to the random-effect matching procedure we used in this study.
| Way | Shot | mini → CUB Before | mini → CUB After | mini → CUB Improvement | tiered → CUB Before | tiered → CUB After | tiered → CUB Improvement |
|---|---|---|---|---|---|---|---|
| 10 | 1 | 37.14 ± 0.12 | 37.41 ± 0.12 | 0.27 ± 0.03 | 64.36 ± 0.16 | 64.52 ± 0.16 | 0.15 ± 0.03 |
| 10 | 5 | 59.77 ± 0.12 | 60.77 ± 0.12 | 1.00 ± 0.04 | 86.23 ± 0.10 | 86.66 ± 0.09 | 0.43 ± 0.03 |
| 15 | 1 | 30.22 ± 0.09 | 30.37 ± 0.09 | 0.15 ± 0.03 | 57.73 ± 0.13 | 57.73 ± 0.13 | 0.00 ± 0.00 |
| 15 | 5 | 52.73 ± 0.09 | 53.84 ± 0.09 | 1.11 ± 0.03 | 82.16 ± 0.09 | 83.05 ± 0.08 | 0.89 ± 0.03 |
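Why can the improvement column have a much tighter interval than either accuracy column? The sketch below is a hypothetical simulation of the general idea of paired evaluation, not the paper's exact random-effect matching procedure: both methods are scored on the same tasks, so a shared per-task difficulty term (the "random effect") cancels in the difference. All numbers are made up, and it uses a normal-approximation 95% interval, which may differ from the interval type reported in the table:

```python
import numpy as np

def ci95(values):
    """Half-width of a normal-approximation 95% confidence interval for the mean."""
    return 1.96 * values.std(ddof=1) / np.sqrt(len(values))

rng = np.random.default_rng(0)
n_tasks = 10_000
# Hypothetical per-task difficulty shared by both methods (the random effect)
task_effect = rng.normal(0.0, 10.0, size=n_tasks)
before = 60.0 + task_effect + rng.normal(0.0, 1.0, size=n_tasks)
after = 61.0 + task_effect + rng.normal(0.0, 1.0, size=n_tasks)
# Both methods see the same tasks, so task_effect cancels in the paired difference
improvement = after - before

print(f"before:      {before.mean():.2f} ± {ci95(before):.2f}")
print(f"after:       {after.mean():.2f} ± {ci95(after):.2f}")
print(f"improvement: {improvement.mean():.2f} ± {ci95(improvement):.2f}")
```

The task-to-task spread dominates the intervals of the raw accuracies, while the paired difference only carries the much smaller method-specific noise.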
Implementing Firth bias reduction for 1-layer logistic and cosine classifiers only takes one or two extra lines of code.
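```python
import torch
import torch.nn as nn
# Hypothetical example inputs: logits (N x C), integer labels (N,), Firth coefficient lam
logits = torch.randn(8, 5, requires_grad=True)
target = torch.randint(0, 5, (8,))
lam = 1.0
ce_loss = nn.CrossEntropyLoss()
ce_term = ce_loss(logits, target)
# This is how you can compute the Firth bias reduction term from classifier logits
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
firth_term = -log_probs.mean()
loss = ce_term + lam * firth_term
loss.backward()
```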
Alternatively, you can use the `label_smoothing` keyword argument in `nn.CrossEntropyLoss`. Remember that this Firth formulation only holds for 1-layer logistic and cosine classifiers; for more complex networks, the log-determinant of the FIM must be worked out explicitly.
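PyTorch's label smoothing (available since PyTorch 1.10) replaces the one-hot target with $(1-\varepsilon)\,\mathbf{y}_i + \varepsilon\,\mathbf{U}$, so the smoothed loss equals $(1-\varepsilon)$ times the plain CE term plus $\varepsilon$ times `firth_term`. The sketch below, with hypothetical random inputs, checks that choosing $\varepsilon = \lambda / (1 + \lambda)$ reproduces `ce_term + lam * firth_term` up to the constant factor $1 + \lambda$:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 5)             # placeholder classifier outputs
target = torch.randint(0, 5, (8,))     # placeholder integer labels
lam = 0.5                              # hypothetical Firth coefficient

# Explicit Firth-penalized loss
ce_term = nn.CrossEntropyLoss()(logits, target)
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
firth_loss = ce_term + lam * (-log_probs.mean())

# Label smoothing: (1 - eps) * CE(y, P) + eps * CE(U, P), averaged over the batch
eps = lam / (1 + lam)
smoothed = nn.CrossEntropyLoss(label_smoothing=eps)(logits, target)
print(torch.allclose(smoothed * (1 + lam), firth_loss))  # True
```

The leftover factor $1 + \lambda$ is a constant, so it only rescales the loss (and hence the effective step size for plain SGD) without moving the minimizer.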
As for the coefficient $\lambda$:
* Firth's original work set it to a pre-determined constant.
* Recent work has proposed scaling Firth's pre-determined coefficient, making $\lambda$ a hyper-parameter.
* We followed the common machine learning practice: we validated $\lambda$ on the validation split, then evaluated the validated $\lambda$ on the novel set.
* You don't need much resolution for the validation search; we performed the search in a log-10 space over a handful of candidates.
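The validation search can be sketched as a simple argmax over a log-10 spaced grid. The candidate grid, `select_firth_coefficient`, and `toy_validation_accuracy` below are hypothetical; in practice the callable would train the few-shot classifier with the given $\lambda$ on validation tasks and return the mean accuracy:

```python
import math

def select_firth_coefficient(candidates, validation_accuracy):
    """Return (lambda, accuracy) for the candidate with the best validation accuracy."""
    best_lam, best_acc = None, -math.inf
    for lam in candidates:
        acc = validation_accuracy(lam)
        if acc > best_acc:
            best_lam, best_acc = lam, acc
    return best_lam, best_acc

# Hypothetical log-10 spaced grid: 0.001, 0.01, 0.1, 1, 10
candidates = [10.0 ** k for k in range(-3, 2)]

def toy_validation_accuracy(lam):
    # Hypothetical stand-in for "fit with this lam on validation tasks, return accuracy";
    # it simply peaks at lam = 0.1
    return 0.70 - 0.01 * (math.log10(lam) + 1.0) ** 2

best_lam, best_acc = select_firth_coefficient(candidates, toy_validation_accuracy)
print(best_lam, best_acc)
```

The selected value is then frozen and reused unchanged when evaluating on the novel split, so no novel-set information leaks into the choice of $\lambda$.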
Our implementation is open-source and available at https://github.com/ehsansaleh/firth_bias_reduction.
Due to the volume of experimental settings in our paper, we broke down the code into three sub-modules:
* The `code_firth` repository corresponds to the Firth bias reduction experiments using standard ResNet architectures and logistic classifiers (e.g., Figures 2 and 3 in the main paper).
* The `code_s2m2rf` repository corresponds to the experiments with cosine classifiers on WideResNet-28 feature stacks trained by the S2M2R method.
* The `code_dcf` repository contains our GPU implementation of the Distribution Calibration (DC) method and the relevant Firth bias reduction improvements.
All of them are standalone repositories with:
* detailed documentation in their corresponding readme files, and
* helper scripts for automatically downloading and extracting the features, datasets, and backbone parameters from external sources (such as Google Drive).
You can clone all three modules with the following command:
```bash
git clone --recursive https://github.com/ehsansaleh/firth_bias_reduction.git
```
We have published all the data, pre-computed features, trained backbone parameters, and other auxiliary files (complementing the open-source code) in two redundant external sources: the Illinois Data Bank and Google Drive.
Illinois Data Bank
We have included our pre-computed features and trained backbones as tarball archives in our Illinois Data Bank repository at https://doi.org/10.13012/B2IDB-1016367_V1, with brief instructions for manually downloading and placing the data.
Google Drive
Here is the arXiv link to our paper:
* The arXiv PDF link: https://arxiv.org/pdf/2110.02529.pdf
* The arXiv web-page link: https://arxiv.org/abs/2110.02529

Here is the OpenReview link to our paper:
* The OpenReview PDF link: https://openreview.net/pdf?id=DNRADop4ksB
* The OpenReview forum link: https://openreview.net/forum?id=DNRADop4ksB
Our paper got a spotlight presentation at ICLR 2022.
We will update this page with links to the presentation video and the web page on iclr.cc.
Here is the BibTeX citation entry for our work:
```bibtex
@inproceedings{ghaffari2022fslfirth,
  title={On the Importance of Firth Bias Reduction In Few-Shot Classification},
  author={Saba Ghaffari and Ehsan Saleh and David Forsyth and Yu-Xiong Wang},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=DNRADop4ksB}
}
```