MFC-GAN / Class-imbalanced dataset classification using Multiple Fake Class Generative Adversarial Network
์†์ง€์šฐ

Ali-Gombe, A., & Elyan, E. (2019). MFC-GAN: class-imbalanced dataset classification using multiple fake class generative adversarial network. Neurocomputing, 361, 212-221.

In Short

Data Generation (Augmentation) with Multiclass Fakes

1. Introduction

Since both minority and majority classes come from the same distribution, these classes share some common features.
=> Features learned from majority classes should aid in learning the minority classes.
=> Class conditioned generation will focus the model into sampling minority classes.

2-1. Lack of Attention to Multiclass Problems

Binary classification has been studied extensively, but multiclass classification has received relatively little attention by comparison. Still, some prior work exists, for example multi-class decomposition, Class Rectification Loss (CRL), and mean squared false error.

  • CRL: it performs hard mining of the minority classes in each batch, forcing the model to create a boundary for each minority class with a hard positive and negative threshold.
  • LMLE (Large Margin Local Embedding): it employs clustering among classes to maintain the structure of the minority data.
    => However, both methods share the drawback that their computational cost grows too large on big datasets.

2-2. SMOTE

Well-established methods such as SMOTE are known to be ineffective on extremely imbalanced data.

2-3. Advances in GANs

C-GAN (Conditional GAN) can be a potential solution for generating minority class data, but it is not a complete one. Likewise, vanilla GANs and AC-GAN have not been effective at generating minority samples when the data are imbalanced. DCGAN and MelanoGAN were effective, but in binary rather than multiclass cases.

3. Method

GAN์€ ๊ธฐ๋ณธ์ ์œผ๋กœ ๋‘ ๊ฐ€์ง€ ์ข…๋ฅ˜์˜ ํ•™์Šต๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ํ•˜๋‚˜๋Š” ๊ธฐ์กด์˜ ํ•™์Šต๋ฐ์ดํ„ฐ์— ์žˆ๋Š” ์›๋ณธ๋ฐ์ดํ„ฐ์ด๊ณ , ๋‚˜๋จธ์ง€ ํ•˜๋‚˜๋Š” Generator๊ฐ€ ์ƒ์„ฑํ•œ ์ƒ˜ํ”Œ๋“ค(fake)์ด๋‹ค.

10๊ฐœ์˜ ํด๋ž˜์Šค๊ฐ€ ์žˆ๋‹ค๊ณ  ํ–ˆ์„ ๋•Œ, ์›๋ž˜๋Š” one-hot encoding ๋ฐฉ์‹์œผ๋กœ 1000000000๋กœ ํ‘œ์‹œํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ์—ฌ๊ธฐ์„œ๋Š” fake image์— ๋Œ€ํ•œ label๋„ ๋”ฐ๋กœ ์‹ ๊ฒฝ์จ์ค€๋‹ค. ๊ทธ๋ ‡๊ธฐ ๋•Œ๋ฌธ์— real image๋Š” 10000000000000000000๋กœ ํ‘œ์‹œํ•˜๊ณ , ์ด์— ํ•ด๋‹นํ•˜๋Š” fake image์˜ label์€ 00000000001000000000๋กœ ํ‘œ์‹œํ•œ๋‹ค.

3-1. Objective Function

Objective1

$L_s$: to estimate the sampling loss, which represents the probability of the sample being real or fake
$L_{cd}$: to estimate the classification loss over the discriminator
$L_{cg}$: to estimate the classification loss over the generator

$L_{cd}$ means that the discriminator classifies samples into their associated class, real or fake
$L_{cg}$ means that the generator aims to have its fake samples classified as real classes

generator: maximizes the difference $L_{cg} - L_s$
discriminator: maximizes the sum $L_s + L_{cd}$

The MFC-GAN generator is sampled using a noise vector conditioned on real class labels.
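The descriptions above can be collected into a plausible reconstruction of the objective (written here in AC-GAN-style notation; the exact expectations are assumptions, since the equations themselves are not reproduced in this post):

```latex
% Sampling (adversarial) loss: real vs. fake
L_s = \mathbb{E}\big[\log P(S=\text{real} \mid X_{\text{real}})\big]
    + \mathbb{E}\big[\log P(S=\text{fake} \mid X_{\text{fake}})\big]

% Discriminator classification loss: real samples to their real-class
% labels c, fake samples to their associated fake-class labels \hat{c}
L_{cd} = \mathbb{E}\big[\log P(C=c \mid X_{\text{real}})\big]
       + \mathbb{E}\big[\log P(C=\hat{c} \mid X_{\text{fake}})\big]

% Generator classification loss: fake samples classified as real classes
L_{cg} = \mathbb{E}\big[\log P(C=c \mid X_{\text{fake}})\big]

% Training objectives
D:\ \max\ (L_s + L_{cd}), \qquad G:\ \max\ (L_{cg} - L_s)
```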

Objective2

label์ด ์—†๋Š” ๊ฒฝ์šฐ์—๋Š” ์œ„์™€ ๊ฐ™์ด Vanilla GAN์ฒ˜๋Ÿผ ์ž‘๋™ํ•˜๊ฒŒ ๋œ๋‹ค.

3-2. MFC-GAN vs. FSC-GAN

MFC-GAN: generator is penalized according to how far the generated sample is from the real class label
FSC-GAN: generator is penalized according to how far the generated sample is from the fake class label
=> this promoted the earlier convergence of MFC-GAN
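As a toy illustration of the difference (all probabilities below are hypothetical): with a 2N-way class output, the generator's classification target for a fake sample of class 3 would be the real-class slot 3 under MFC-GAN, but the fake-class slot 13 under FSC-GAN:

```python
import numpy as np

def cross_entropy(probs, target_idx):
    # negative log-likelihood of the target class slot
    return -np.log(probs[target_idx])

n_classes = 10
# Hypothetical discriminator output over 2*n_classes slots for one
# fake sample of class 3
probs = np.full(2 * n_classes, 0.02)
probs[3] = 0.3    # mass on the real-class slot
probs[13] = 0.42  # mass on the fake-class slot
probs /= probs.sum()

g_loss_mfc = cross_entropy(probs, 3)    # MFC-GAN: target = real-class slot
g_loss_fsc = cross_entropy(probs, 13)   # FSC-GAN: target = fake-class slot
```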

4. Experiments

Algorithm

4-1. Experimental set-up

  1. ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ: tensorflow 1.0, Keras 2.0
  2. ๋น„๊ต๋ชจ๋ธ: SMOTE, AC-GAN, FSC-GAN, ์›๋ฐ์ดํ„ฐ
  3. ๊ณตํ†ต ๋ถ„๋ฅ˜๋ถ„์„๊ธฐ: CNN
  4. ๊ฒฐ๊ณผ๋น„๊ต๊ธฐ์ค€:
  • ์ฃผ๊ด€์ : plausibility of sample(i.e., visual inspection)
  • ๊ฐ๊ด€์ : ๋‹ค์–‘ํ•œ ์ง€ํ‘œ๋“ค์„ ํ†ตํ•ด ๋ถ„๋ฅ˜๊ฒฐ๊ณผ ์„ฑ๋Šฅ ๋น„๊ต

4-2. Dataset

๋ฐ์ดํ„ฐ: MNIST, E-MNIST, SVHN, CIFAR-10

  • Most datasets are artificially made imbalanced before use (specific classes are chosen at random and undersampled)
  • MNIST: two classes are chosen at random and only 1% of the original data (50 samples) per class is kept (a different pair of classes is chosen on each run)
  • E-MNIST: of roughly 810,000 samples in total, 21 of the 62 classes have fewer than 3,000 samples each; that is, it is already imbalanced. The 10 smallest classes were used (G, K, Q, f, j, k, m, p, s, y)
  • SVHN: as with MNIST, only 50 samples from each of two classes (1, 2) were used, forcing imbalance
  • CIFAR-10: from two of the ten classes (Aeroplane, Automobile), only 50 of the 1,000 samples each were kept, again forcing imbalance
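The undersampling procedure above can be sketched roughly as follows (a minimal illustration; the function name and defaults are hypothetical, not from the paper):

```python
import numpy as np

def force_imbalance(X, y, n_minority=2, keep=50, seed=0):
    """Pick n_minority classes at random and undersample each
    down to `keep` samples; all other classes are kept whole."""
    rng = np.random.default_rng(seed)
    minority = rng.choice(np.unique(y), size=n_minority, replace=False)
    keep_idx = []
    for c in np.unique(y):
        idx = np.where(y == c)[0]
        if c in minority:
            idx = rng.choice(idx, size=keep, replace=False)
        keep_idx.append(idx)
    keep_idx = np.concatenate(keep_idx)
    return X[keep_idx], y[keep_idx], minority
```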

The FSC-GAN architecture was adopted for MNIST and E-MNIST, and the AC-GAN architecture for SVHN and CIFAR-10. In addition, for a fair comparison, spectral weight normalization was added to both the generator and the discriminator in AC-GAN, FSC-GAN, and MFC-GAN.
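As an aside, the core of spectral weight normalization is dividing a weight matrix by its largest singular value, usually estimated by power iteration. A minimal numpy sketch of that idea (not the paper's implementation):

```python
import numpy as np

def spectral_normalize(W, n_iter=500):
    """Estimate the largest singular value of W by power iteration
    and return W rescaled to (approximately) unit spectral norm."""
    u = np.random.default_rng(0).normal(size=W.shape[0])
    for _ in range(n_iter):
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    sigma = u @ W @ v  # estimated spectral norm of W
    return W / sigma, sigma

# Demo on a small random weight matrix
W = np.random.default_rng(1).normal(size=(4, 3))
W_sn, sigma = spectral_normalize(W)
```

In practice the normalization is applied to each layer's weights at every training step, with only one or two power iterations since the weights change slowly.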

4-3. Image Generation

Figure1

  • What the network switcher does
    Depending on the availability of labels, the network switcher feature enables both models to alternate between two training modes. The switcher is a piece-wise function that oscillates between supervised and unsupervised training.
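As a rough illustration, the switcher can be thought of as a piece-wise choice of loss terms depending on label availability (the function names and loss values below are hypothetical):

```python
def training_mode(labels):
    """Piece-wise switcher: supervised when labels are available,
    otherwise unsupervised (vanilla-GAN style)."""
    return "supervised" if labels is not None else "unsupervised"

def discriminator_loss(l_s, l_cd, labels):
    # With labels: adversarial + classification terms.
    # Without labels: the adversarial term only, as in a vanilla GAN.
    if training_mode(labels) == "supervised":
        return l_s + l_cd
    return l_s

labeled_loss = discriminator_loss(1.0, 0.5, labels=[3])   # 1.5
unlabeled_loss = discriminator_loss(1.0, 0.5, labels=None)  # 1.0
```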

4-4. Image Classification

CNN์„ ๊ณตํ†ต ๋ถ„๋ฅ˜๋ถ„์„๊ธฐ๋กœ ํ™œ์šฉํ•˜์˜€๋‹ค.

  • MNIST, E-MNIST, SVHN: 3 layers with a softmax output layer (2 convolution layers with 3x3 kernels and 2x2 max-pooling, two filter maps, and a fully connected layer), ReLU activations, 0.5 dropout ratio, Adadelta optimizer
  • CIFAR-10: 3 convolution layers, 0.2 dropout ratio, SGD optimizer

5. Results

Figure2


Figure3


Figure4


Figure5

  1. ์ƒ์„ฑ๋œ ์‚ฌ์ง„์˜ ํ€„๋ฆฌํ‹ฐ๊ฐ€ ์šฐ์ˆ˜ํ•จ
    ๋”ฑ ๋ณด๋”๋ผ๋„ MFC-GAN์ด ํ›จ์”ฌ ์šฐ์ˆ˜ํ•œ ์„ฑ๋Šฅ์œผ๋กœ ์ด๋ฏธ์ง€๋ฅผ ๋งŒ๋“ค์–ด๋ƒ„์„ ๋ˆˆ์œผ๋กœ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค. ํŠนํžˆ (c)๋ฅผ ๋ณด๋ฉด, unlabeled๊ฐ€ 5๋งŒ๊ฐœ์˜€๋Š”๋ฐ๋„ ์„ฑ๋Šฅ์ด ๊ดœ์ฐฎ๊ฒŒ ๋‚˜์™”๋‹ค.

  2. ์ปดํ“จํŒ… ํšจ์œจ์„ฑ
    FSC-GAN์€ 500 epoch๊ฐ€ ํ•„์š”ํ•œ ๋ฐ์— ๋น„ํ•ด, MFC-GAN์€ 50 epoch๋งŒ์„ ํ•„์š”๋กœ ํ–ˆ๋‹ค. ์ฆ‰, data augmentation์—๋Š” MFC-GAN์ด ๋ณด๋‹ค ์šฐ์ˆ˜ํ•จ์„ ์•Œ ์ˆ˜ ์žˆ์—ˆ๋‹ค.

  3. ๊ฐ๊ด€์  ์„ฑ๋Šฅ์ง€ํ‘œ์—์„œ๋„ ์šฐ์ˆ˜ํ•จ
    MFC-GAN์—์„œ ๋ฏผ๊ฐ๋„, balanced accuracy, G-mean ๋ชจ๋‘ ๋†’๊ฒŒ ๋‚˜ํƒ€๋‚ฌ๋‹ค. ์ด์— ๋ฐ˜ํ•ด, FSC-GAN์€ ๋ชจ๋“  ๊ฒฝ์šฐ์—์„œ ์„ฑ๋Šฅ ํ–ฅ์ƒ์œผ๋กœ ์ด์–ด์ง€์ง€ ์•Š์•˜๋‹ค. ํŠน์ • ์ƒํ™ฉ์—์„œ๋Š” SMOTE๊ฐ€ ๋” ์šฐ์ˆ˜ํ•œ ๋“ฏ ๋ณด์ด๊ธฐ๋Š” ํ–ˆ์ง€๋งŒ, ์ด๋Š” ๋‹จ์ˆœํžˆ ์†Œ์ˆ˜ ์ง‘๋‹จ์— ๋Œ€ํ•ด ์ƒ˜ํ”Œ์ด ๋” ๋งŽ๊ธฐ ๋•Œ๋ฌธ์ผ ๊ฑฐ๋ผ๊ณ  ์ถ”์ธก๋œ๋‹ค. ์™œ๋ƒํ•˜๋ฉด ์ƒ˜ํ”Œ ์ˆ˜๊ฐ€ ์ ์€ ํด๋ž˜์Šค์— ๋Œ€ํ•ด์„œ๋Š” ์„ฑ๋Šฅ์ด ์•ˆ ์ข‹์€ ๊ฒƒ์ด ํ™•์ธ๋˜์—ˆ๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.

6. Discussion

The fidelity and diversity of MFC-GAN minority samples made classification easier for the CNN. The diversity of generated samples indicates no sign of mode collapse in the model.

ํ•œ๊ณ„์  1. CIFAR-10์—์„œ์˜ ๋ถ€์กฑํ•œ ์„ฑ๋Šฅ
Table3์—์„œ ์•Œ ์ˆ˜ ์žˆ๋“ฏ์ด, CIFAR-10์—์„œ๋Š” ๋ชจ๋“  ๋ชจ๋ธ๋“ค์ด ์„ฑ๋Šฅ์ด ์ข‹์ง€ ์•Š๊ฒŒ ๋‚˜ํƒ€๋‚ฌ๋‹ค. ์ด๋Š” 32x32๋ผ๋Š” ์ž‘์€ ์ด๋ฏธ์ง€ ์‚ฌ์ด์ฆˆ์— ๋น„ํ•ด ๋งŽ์€ ์ •๋ณด๋Ÿ‰์ด ๋“ค์–ด๊ฐ€์žˆ๋Š” ๋ถ„๋ฅ˜์ž‘์—…์ด๊ธฐ ๋•Œ๋ฌธ์ด๋ผ๊ณ  ๋ณด์ธ๋‹ค. (๊ทธ๋ž˜๋„ ๊ทธ์ค‘์—์„œ ๊ทธ๋‚˜๋งˆ MFC-GAN์ด ๋‚ซ๊ธฐ๋Š” ํ–ˆ๋‹ค.)

ํ•œ๊ณ„์ 2. ํŠน์ • ํด๋ž˜์Šค์—์„œ์˜ ๋ถ€์กฑํ•œ ์„ฑ๋Šฅ
ํŠน์ • ํด๋ž˜์Šค์— ๋Œ€ํ•ด์„œ๋Š” ๋ชจ๋“  ๋ชจ๋ธ๋“ค์ด ๋ชจ๋‘ ์„ฑ๋Šฅ์ด ์•ˆ ์ข‹๊ฒŒ ๋‚˜ํƒ€๋‚ฌ๋‹ค. E-MNIST์˜ ๊ฒฝ์šฐ์—๋Š” m๊ณผ s๊ฐ€ ์ด์— ํ•ด๋‹นํ•œ๋‹ค. ์ด๋Š” s๊ฐ€ 5, S, 2, z์™€ ๋น„์Šทํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋ผ๊ณ  ์ €์ž๋“ค์€ ๋ณด๊ณ  ์žˆ๋‹ค.

Critical Point (MY OWN OPINION)

  1. We cannot know how the results would have changed had different classes been chosen as the minority classes in E-MNIST, SVHN, and CIFAR-10. This leaves the limitation that we cannot conclude the results are independent of the characteristics of the selected classes.

  2. It would have been worth trying architectures other than the basic CNN, such as ResNet or YOLO.

์ฐธ๊ณ ์‚ฌ์ดํŠธ

[1] Spectral Normalization(1)
[2] Spectral Normalization(2)

Further Study Needed

  • [1] Few-Shot Classifier GAN (FSC-GAN)
    Ali-Gombe, A., Elyan, E., Savoye, Y., & Jayne, C. (2018, July). Few-shot classifier GAN. In 2018 International Joint Conference on Neural Networks (IJCNN) (pp. 1-8). IEEE.
  • [2] Advanced GANs
    ์‚ฌ์ดํŠธ1
    ์‚ฌ์ดํŠธ2