Ali-Gombe, A., & Elyan, E. (2019). MFC-GAN: class-imbalanced dataset classification using multiple fake class generative adversarial network. Neurocomputing, 361, 212-221.
In Short
Data Geneation(Augmentation) with Multiclass Fakes
1. Introduction
Since both minority and majority classes come from the same distribution, these classes share some common features.
=> Features learned from majority classes should aid in learning the minority classes.
=> Class conditioned generation will focus the model into sampling minority classes.
2. Related Works
2-1. Mutliclass์ ๋ํ ๊ด์ฌ ๋ถ์กฑ
binary classification์ ๋ํด์๋ ์ฐ๊ตฌ๊ฐ ๋ง์ด ์ด๋ฃจ์ด์ก์ง๋ง, ์ด์ ๋นํด multiclass classification์๋ ์๋์ ์ผ๋ก ๊ด์ฌ์ด ์ ์๋ค. ๊ทธ๋ผ์๋ ์ด์ ๊ด๋ จ๋ ์ ํ์ฐ๊ตฌ๋ค์ด ์๊ธด ํ๋ฐ, ์๋ฅผ ๋ค์ด์ multi-class decomposition, Class Rectification Loss(CRL), mean squared false error ๋ฑ์ด ์๋ค.
- CRL: it performs hard mining of the minority class is each each batch forcing the model to create a boundary for each minority class with a hard positive and negative threshold.
- LMLE(Large Margin Local Embedding): it employs clustering among classes to maintain the structure of the minroity data.
=> ํ์ง๋ง ์์ ๋ ๋ฐฉ๋ฒ๋ก ์ ๋ฐ์ดํฐ๊ฐ ํด ๊ฒฝ์ฐ ๊ณ์ฐ๋์ด ๋๋ฌด ๋ง์์ง๋ค๋ ๋จ์ ๋ค์ด ์๋ค.
2-2. SMOTE
SMOTE์ ๊ฐ์ด ๊ธฐ์กด์ ์ ๋ช ํ ๋ฐฉ์๋ค์ ๊ทน๋จ์ ์ผ๋ก ๋ถ๊ท ํ์ธ ๋ฐ์ดํฐ์ ๋ํด์๋ ๋นํจ์จ์ ์ด๋ผ๊ณ ์๋ ค์ ธ์๋ค.
2-3. GAN์ ๋ฐ์
minority class data๋ฅผ ์์ฑํ๊ธฐ ์ํด์ C-GAN(Conditional GAN)์ด ์ ์ฌ์ ํด๊ฒฐ๋ฐฉ์์ด ๋ ์๋ ์์ง๋ง, ์๋ฒฝํ์ง๋ ์๋ค. ์ด์ฒ๋ผ vanilla GAN์ด๋ AC-GAN๋ค๋ ๋ฐ์ดํฐ๊ฐ ๋ถ๊ท ํํ ์ํ์์๋ minority sample์ ์์ฑํ๋ ๋ฐ์ ์์ด์ ํจ๊ณผ์ ์ด์ง ๋ชปํ๋ค. DCGAN์ด๋ MelanoGAN์ ํจ๊ณผ์ ์ด๊ธฐ๋ ํ์ผ๋, multiclass๊ฐ ์๋ binary case์์์๋ค.
3. Method
GAN์ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ ๊ฐ์ง ์ข ๋ฅ์ ํ์ต๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ค. ํ๋๋ ๊ธฐ์กด์ ํ์ต๋ฐ์ดํฐ์ ์๋ ์๋ณธ๋ฐ์ดํฐ์ด๊ณ , ๋๋จธ์ง ํ๋๋ Generator๊ฐ ์์ฑํ ์ํ๋ค(fake)์ด๋ค.
10๊ฐ์ ํด๋์ค๊ฐ ์๋ค๊ณ ํ์ ๋, ์๋๋ one-hot encoding ๋ฐฉ์์ผ๋ก 1000000000๋ก ํ์ํ๋ค. ํ์ง๋ง ์ฌ๊ธฐ์๋ fake image์ ๋ํ label๋ ๋ฐ๋ก ์ ๊ฒฝ์จ์ค๋ค. ๊ทธ๋ ๊ธฐ ๋๋ฌธ์ real image๋ 10000000000000000000๋ก ํ์ํ๊ณ , ์ด์ ํด๋นํ๋ fake image์ label์ 00000000001000000000๋ก ํ์ํ๋ค.
3-1. Objective Function
$L_s$: to estimate the sampling loss, which represents the probability of the sample being real or fake
$L_{cd}$: to estimate the classification loss over the discriminator
$L_{cg}$: to estimate the classification loss over the generator
$L_{cd}$ means that the discriminator classifies samples as real or fake with associated class
$L_{cg}$ means that the generator classifies fake samples as real classes
generator: to maximize the difference of $L_s$ and $L_{cg}$
discriminator: to maximize the sum of $L_s$ and $L_{cd}$
MFC-GAN generator is sampled using a noise vector conditioned on real class labels
label์ด ์๋ ๊ฒฝ์ฐ์๋ ์์ ๊ฐ์ด Vanilla GAN์ฒ๋ผ ์๋ํ๊ฒ ๋๋ค.
๋ณธ ๋ ผ๋ฌธ์ ํต์ฌ์ discriminator์ด๋ค. ๊ฐ๋ น, AC-GAN์ discriminator๋
$L_{cd}$๊ฐ ์๋๋ผ $L_{cg}$์ $L_s$์ ํฉ์ ์ต๋ํํ๋ ๊ฒ์ด ๋ชฉํ์ด๋ค.3-2. MFC-GAN vs. FSC-GAN
MFC-GAN: generator is penalized according to how far the generated sample is from the real class label
FSC-GAN: generator is penalized according to how far the generated sample is from the fake class label
=> this promoted early convergence of MFC-GAN
4. Experiments
4-1. Experimental set-up
- ๋ผ์ด๋ธ๋ฌ๋ฆฌ:
tensorflow 1.0,Keras 2.0 - ๋น๊ต๋ชจ๋ธ: SMOTE, AC-GAN, FSC-GAN, ์๋ฐ์ดํฐ
- ๊ณตํต ๋ถ๋ฅ๋ถ์๊ธฐ: CNN
- ๊ฒฐ๊ณผ๋น๊ต๊ธฐ์ค:
- ์ฃผ๊ด์ : plausibility of sample(i.e., visual inspection)
- ๊ฐ๊ด์ : ๋ค์ํ ์งํ๋ค์ ํตํด ๋ถ๋ฅ๊ฒฐ๊ณผ ์ฑ๋ฅ ๋น๊ต
4-2. Dataset
๋ฐ์ดํฐ: MNIST, E-MNIST, SVHN, CIFAR-10
- ๋๋ถ๋ถ ๊ฐ์ ๋ก imbalanced data๋ก ๋ง๋ค์ด์ฃผ๊ณ ๋์ ํ์ฉํจ(ํน์ ํด๋์ค ์์๋ก ์ ํ ํ undersampling)
MNIST: 2๊ฐ์ ํด๋์ค๋ฅผ ์์๋ก ๊ณจ๋ผ ์๋ฐ์ดํฐ์ 1%(50๊ฐ)์ฉ๋ง ํ์ฉ (run๋ง๋ค ๋ค๋ฅธ ํด๋์ค ์ ํ)E-MNIST: ์ด 81๋ง ์ฌ๊ฐ ์ํ์์, 62๊ฐ ํด๋์ค ์ค 21๊ฐ์ ํด๋์ค์ ํด๋นํ๋ ์ํ๋ค์ด ๊ฐ๊ฐ 3000๊ฐ ์ดํ์ด๋ค. ์ฆ, ์ด๋ฏธ imbalanced์ด๋ค. ์ฌ๊ธฐ์ ๊ฐ์ฅ ์ ์ 10๊ฐ์ ํด๋์ค๋ฅผ ํ์ฉํ์๋ค. (G, K, Q, f, j, k, m, p, s, y)SVHN: MNIST์ฒ๋ผ 2๊ฐ์ ํด๋์ค์์ 50๊ฐ์ฉ๋ง ์ฌ์ฉํ์ฌ ๊ฐ์ ๋ก imbalanced๋ฅผ ๋ง๋ค์ด์ฃผ์๋ค. (1,2)CIFAR-10: 10๊ฐ ํด๋์ค ์ค ๋ ๊ฐ์ ํด๋์ค์์ 1000๊ฐ ์ค์์ 50๊ฐ์ฉ๋ง ํ์ฉํ์๋ค. ์ฆ, ๊ฐ์ ๋ก imbalanced๋ฅผ ๋ง๋ค์ด์ฃผ์๋ค. (Aeroplane, Automobile)
MNIST์ E-MNIST์ ๊ฒฝ์ฐ์๋ FSC-GAN์ ๊ตฌ์กฐ๋ฅผ ์ฐจ์ฉํ์๊ณ , SVHN๊ณผ CIFAR-10์์๋ AC-GAN์ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์๋ค. ๊ทธ๋ฆฌ๊ณ , AC-GAN, FSC-GAN, MFC-GAN ๋ชจ๋์ ๊ฒฝ์ฐ์ ๊ณต์ ํ ๋น๊ต๋ฅผ ์ํด spectral weight normalization์ ์์ฑ๊ธฐ์ ๋ถ๋ฅ๊ธฐ ๋ชจ๋์ ์ถ๊ฐํด์ฃผ์๋ค.
4-3. Image Generation
- network switcher์ ์๋ฏธ
Depending on the availability of labels, the network switcher feature enables both models to alternate between two training modes. The switcher is a piece-wise function that oscillates between supervised and unsupervised training.
4-4. Image Classification
CNN์ ๊ณตํต ๋ถ๋ฅ๋ถ์๊ธฐ๋ก ํ์ฉํ์๋ค.
- MNIST, E-MNIST, SVHN: 3 layers with softmax activation layer (2 convolution layers, 3x3 kernels with 2x2 max-pooling, two filter maps, fully connected layer), ReLuactivated, 0.5 dropout ratio, Adadelta optimizer
- CIFAR-10: 3 convolution layers, 0.2 dropout ratio, SGD optimizer
5. Results
-
์์ฑ๋ ์ฌ์ง์ ํ๋ฆฌํฐ๊ฐ ์ฐ์ํจ
๋ฑ ๋ณด๋๋ผ๋ MFC-GAN์ด ํจ์ฌ ์ฐ์ํ ์ฑ๋ฅ์ผ๋ก ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค์ด๋์ ๋์ผ๋ก ํ์ธํ ์ ์๋ค. ํนํ (c)๋ฅผ ๋ณด๋ฉด, unlabeled๊ฐ 5๋ง๊ฐ์๋๋ฐ๋ ์ฑ๋ฅ์ด ๊ด์ฐฎ๊ฒ ๋์๋ค. -
์ปดํจํ ํจ์จ์ฑ
FSC-GAN์ 500 epoch๊ฐ ํ์ํ ๋ฐ์ ๋นํด, MFC-GAN์ 50 epoch๋ง์ ํ์๋ก ํ๋ค. ์ฆ, data augmentation์๋ MFC-GAN์ด ๋ณด๋ค ์ฐ์ํจ์ ์ ์ ์์๋ค. -
๊ฐ๊ด์ ์ฑ๋ฅ์งํ์์๋ ์ฐ์ํจ
MFC-GAN์์ ๋ฏผ๊ฐ๋, balanced accuracy, G-mean ๋ชจ๋ ๋๊ฒ ๋ํ๋ฌ๋ค. ์ด์ ๋ฐํด,FSC-GAN์ ๋ชจ๋ ๊ฒฝ์ฐ์์ ์ฑ๋ฅ ํฅ์์ผ๋ก ์ด์ด์ง์ง ์์๋ค. ํน์ ์ํฉ์์๋SMOTE๊ฐ ๋ ์ฐ์ํ ๋ฏ ๋ณด์ด๊ธฐ๋ ํ์ง๋ง, ์ด๋ ๋จ์ํ ์์ ์ง๋จ์ ๋ํด ์ํ์ด ๋ ๋ง๊ธฐ ๋๋ฌธ์ผ ๊ฑฐ๋ผ๊ณ ์ถ์ธก๋๋ค. ์๋ํ๋ฉด ์ํ ์๊ฐ ์ ์ ํด๋์ค์ ๋ํด์๋ ์ฑ๋ฅ์ด ์ ์ข์ ๊ฒ์ด ํ์ธ๋์๊ธฐ ๋๋ฌธ์ด๋ค.
6. Discussion
The fidelity and diversity of MFC-GAN minority samples made classification easier for the CNN. The diversity of generated samples indicates no sign of mode collapse in the model.
ํ๊ณ์ 1. CIFAR-10์์์ ๋ถ์กฑํ ์ฑ๋ฅ
Table3์์ ์ ์ ์๋ฏ์ด, CIFAR-10์์๋ ๋ชจ๋ ๋ชจ๋ธ๋ค์ด ์ฑ๋ฅ์ด ์ข์ง ์๊ฒ ๋ํ๋ฌ๋ค. ์ด๋ 32x32๋ผ๋ ์์ ์ด๋ฏธ์ง ์ฌ์ด์ฆ์ ๋นํด ๋ง์ ์ ๋ณด๋์ด ๋ค์ด๊ฐ์๋ ๋ถ๋ฅ์์
์ด๊ธฐ ๋๋ฌธ์ด๋ผ๊ณ ๋ณด์ธ๋ค. (๊ทธ๋๋ ๊ทธ์ค์์ ๊ทธ๋๋ง MFC-GAN์ด ๋ซ๊ธฐ๋ ํ๋ค.)
ํ๊ณ์ 2. ํน์ ํด๋์ค์์์ ๋ถ์กฑํ ์ฑ๋ฅ
ํน์ ํด๋์ค์ ๋ํด์๋ ๋ชจ๋ ๋ชจ๋ธ๋ค์ด ๋ชจ๋ ์ฑ๋ฅ์ด ์ ์ข๊ฒ ๋ํ๋ฌ๋ค. E-MNIST์ ๊ฒฝ์ฐ์๋ m๊ณผ s๊ฐ ์ด์ ํด๋นํ๋ค. ์ด๋ s๊ฐ 5, S, 2, z์ ๋น์ทํ๊ธฐ ๋๋ฌธ์ด๋ผ๊ณ ์ ์๋ค์ ๋ณด๊ณ ์๋ค.
—
Critical Point (MY OWN OPINION)
-
E-MINST, SVHN, CIFAR-10์์ ๋ค๋ฅธ ํด๋์ค๊ฐ minority class๋ก์ ์ ์ ์ด ๋์๋ค๋ฉด ์ด๋ป๊ฒ ๊ฒฐ๊ณผ๊ฐ ๋ฌ๋ผ์ก์์ง ์ ์๊ฐ ์๋ค. ์ ํ๋ ํด๋์ค์ ํน์ฑ์ ๋ฐ๋ฅธ ๊ฒฐ๊ณผ๊ฐ ์๋๋ผ๊ณ ๊ฒฐ๋ก ์ง์ ์๊ฐ ์๋ค๋ ํ๊ณ์ ์ด ์๋ค.
-
๊ธฐ๋ณธ CNN ๋ชจํ ๋ง๊ณ ResNet๋ YOLO์ ๊ฐ์ด ๋ค๋ฅธ architecture๋ฅผ ์ฌ์ฉํ์ผ๋ฉด ์ด๋ ์๊น๋ผ๋ ์๊ฐ์ด ์๋ค.
—
์ฐธ๊ณ ์ฌ์ดํธ
[1] Spectral Normalization(1)
[2] Spectral Normalization(2)
Further Study Needed
- [1] Few-Shot Classifier GAN (FSC-GAN)
Ali-Gombe, A., Elyan, E., Savoye, Y., & Jayne, C. (2018, July). Few-shot classifier GAN. In 2018 International Joint Conference on Neural Networks (IJCNN) (pp. 1-8). IEEE. - [2] Advanced GANs
์ฌ์ดํธ1
์ฌ์ดํธ2