Unofficial PyTorch implementation of "Deep Bayesian Active Learning with Image Data" by Yarin Gal, Riashat Islam and Zoubin Ghahramani (ICML 2017).
In this paper, Gal et al. combine recent advances in Bayesian deep learning with the active learning framework in a practical way, obtaining an active learning framework for high-dimensional data, a task that had been extremely challenging until then.
By taking advantage of specialised models such as Bayesian convolutional neural networks, the proposed technique obtains a significant improvement over existing active learning approaches.
This repo compares various acquisition functions, all relying on Bayesian CNN uncertainty, on a simple image classification benchmark: Bayesian Active Learning by Disagreement (BALD; Houlsby et al., 2011), Variation Ratios (Freeman, 1965), Max Entropy (Shannon, 1948), Mean STD (Kampffmeyer et al., 2016; Kendall et al., 2015) and a Random baseline. All acquisition functions are assessed with the same model structure:
Convolution-relu-convolution-relu-max pooling-dropout-dense-relu-dropout-dense-softmax
with 32 convolution kernels, 4x4 kernel size, 2x2 pooling, a dense layer with 128 units, and dropout probabilities 0.25 and 0.5.
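A minimal PyTorch sketch of this architecture (the class name and layer grouping are illustrative, not the repo's actual code; the softmax is folded into the cross-entropy loss during training):

```python
import torch.nn as nn

class BayesianCNN(nn.Module):
    # convolution-relu-convolution-relu-max pooling-dropout-dense-relu-dropout-dense(-softmax)
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4),   # 32 kernels, 4x4: 28x28 -> 25x25
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4),  # 25x25 -> 22x22
            nn.ReLU(),
            nn.MaxPool2d(2),                   # 2x2 pooling: 22x22 -> 11x11
            nn.Dropout(0.25),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 11 * 11, 128),      # dense layer with 128 units
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),       # softmax is applied via the loss
        )

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))
```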
All models are trained on the MNIST dataset with a random initial training set of 20 data points and a validation set of 100 points used to optimise the weight decay. The standard 10K test set is used, and the remaining points serve as the pool set. The test error of each model and acquisition function is assessed after each acquisition, using the dropout approximation at test time.
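For concreteness, a rough sketch of that split (the variable names and the use of torchvision are assumptions; the repo may load MNIST differently):

```python
import numpy as np
from torchvision import datasets

mnist = datasets.MNIST(root="data", train=True, download=True)
X, y = mnist.data.numpy(), mnist.targets.numpy()

rng = np.random.default_rng(369)  # seed matches the example command below
perm = rng.permutation(len(X))
train_idx = perm[:20]    # random initial training set of 20 points
val_idx = perm[20:120]   # 100 validation points (for tuning weight decay)
pool_idx = perm[120:]    # the remaining points form the pool set
```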
Monte Carlo dropout is used to decide which data points to query next. The acquisition process is repeated 100 times, each time acquiring the 10 points that maximise the acquisition function (1,000 acquired points in total).
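A hedged sketch of this scoring step, assuming a `model` like the one above and a float pool tensor `X_pool` of shape (N, 1, 28, 28); the function names are illustrative, not the repo's API:

```python
import torch
import torch.nn.functional as F

def mc_dropout_probs(model, X_pool, T=100):
    model.train()  # keep the dropout layers sampling at test time (MC dropout)
    with torch.no_grad():
        # (T, N, C) stack of stochastic softmax outputs
        return torch.stack([F.softmax(model(X_pool), dim=1) for _ in range(T)])

def max_entropy(probs, eps=1e-10):
    mean_p = probs.mean(dim=0)                          # predictive distribution
    return -(mean_p * (mean_p + eps).log()).sum(dim=1)  # H[y | x, D_train]

def bald(probs, eps=1e-10):
    expected_entropy = -(probs * (probs + eps).log()).sum(dim=2).mean(dim=0)
    return max_entropy(probs, eps) - expected_entropy   # mutual information I[y; w]

def variation_ratio(probs):
    return 1.0 - probs.mean(dim=0).max(dim=1).values    # 1 - max class probability

def mean_std(probs):
    return probs.std(dim=0).mean(dim=1)  # std over T samples, averaged over classes

probs = mc_dropout_probs(model, X_pool, T=100)
query_idx = bald(probs).topk(10).indices  # acquire the 10 highest-scoring points
```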
This repo consists of 4 experiments conducted in the paper which are:
- Comparison of various acquisition functions
- Importance of model uncertainty
- Comparison to current active learning techniques with image data (Minimum Bayes Risk, MBR)
- Comparison to semi-supervised learning
- Python 3.5 or later
On Ubuntu, you can install Python 3 like this:
$ sudo apt-get install python3 python3-pip
For Windows and MacOS, please refer to https://www.python.org/getit/.
Use
$ pip3 install -r requirements.txt
to install:
- pytest (for testing)
- modAL (a modular active learning framework, used here via skorch)
- skorch (a scikit-learn wrapper for PyTorch)
- pytorch
- numpy
- matplotlib
- scipy
- scikit-learn
- To run the tests:
$ pytest
Run
$ python3 main.py --batch_size 128 \
--epochs 50 \
--lr 1e-3 \
--seed 369 \
--experiments 3 \
--dropout_iter 100 \
--query 10 \
--acq_func 0 \
--val_size 100 \
--result_dir result_npy
Or use `--help` for more info.
- `--determ` is set as `False` by default for Experiment 1; add this flag to run Experiment 2.
- `--val_size` is set as `100` by default for Experiment 1. To run Experiment 4, please set this to `5000`.
- In this implementation, `acquisition_iterations = dropout_iterations = 100`.
- For Experiment 3, please refer to `comparison_to_MBR.ipynb` or the Google Colab link here.
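For example, keeping the other defaults shown above, Experiment 2 could be run with
$ python3 main.py --determ
and Experiment 4 with
$ python3 main.py --val_size 5000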
Number of acquired images needed to reach a given model error (lower is better):
Technique | 10% error (Paper: Keras) | 10% error (Experiment: PyTorch) | 5% error (Paper: Keras) | 5% error (Experiment: PyTorch) |
---|---|---|---|---|
Random (Baseline) | 255 | 250 | 835 | 517 |
Mean STD | 230 | 100 | 695 | 295 |
BALD | 145 | 150 | 335 | 296 |
Var Ratios | 120 | 143 | 295 | 283 |
Max Entropy | 165 | 163 | 355 | 310 |
To further reduce computation time, a random subset of 2,000 points is used for scoring instead of the whole pool. These data points are selected from the pool using
np.random.choice(range(len(X_pool)), size=2000, replace=False)
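A short sketch of how this subset might plug into the scoring step (`acquisition_fn` is a hypothetical stand-in for any of the scoring functions sketched earlier, returning one NumPy score per point):

```python
import numpy as np

# Score only a random 2000-point subset of the pool, then map the 10 winning
# positions back to indices into the full pool.
subset_idx = np.random.choice(range(len(X_pool)), size=2000, replace=False)
scores = acquisition_fn(model, X_pool[subset_idx])  # hypothetical; shape (2000,)
query_idx = subset_idx[np.argsort(scores)[-10:]]    # indices into X_pool
```

Approximate running time per acquisition function with this subset: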
- Random: ~2m 17s
- BALD: ~10m 52s
- Var Ratios: ~10m 58s
- Max Entropy: ~10m 39s
- Mean STD: ~10m 40s
Best 2 acquisition functions: Mean STD, Var Ratios
This experiment is run on a binary classification test (MNIST two-digit classification).
Test error on MNIST with 1,000 acquired images, using 5,000 validation points:
Technique | Test error (Paper: Keras) | Test error (Experiment: PyTorch) |
---|---|---|
Random (Baseline) | 4.66% | 3.73% |
Mean STD | - | 1.81% |
BALD | 1.80% | 1.81% |
Max Entropy | 1.74% | 1.66% |
Var Ratios | 1.64% | 1.57% |
Best 2 acquisition functions: Var Ratios, Max Entropy