[DL] 06. 새로운 데이터를 생성하는 법, GAN?

2020. 12. 8. 06:46ML&DL

순천향대학교 빅데이터 공학과 김정현 교수님의 강의를 바탕으로 정리한 글이며 수업 자료의 저작권 문제로 인해 수업 자료를 직접 이용하지 않았음을 먼저 밝힙니다. 문제가 될 시 바로 삭제하겠습니다.

 

 

이번 글은 새로운 데이터를 만들어내는 Generative Adversarial Network 모형에 대해 알아보겠습니다.

 

순서

1. Generative Model 학습 방법

2. GAN?

3. GAN 학습 방법

4. GAN 코드(Tensorflow)

 

Generative Model 학습 방법

GAN 모델에 대해 알아보기 전 Generative Model에 대해 알아보겠습니다. Generative Model은 랜덤 한 latent code를 입력으로 넣으면 출력으로 새로운 데이터가 나오는 학습 방법입니다. 그럼 latent code가 무엇일까요? 많은 학습 데이터가 주어졌을 때 관찰을 잘 설명할 수 있는 공간을 latent space(잠재 공간)이라고 합니다. 실제 관찰 공간보다 잠재 공간은 feature의 유용성을 확인한 후 작아질 수도 있습니다. 이렇듯 학습을 위해 정말 필요한 feature들만 추출하여 모델에 적용합니다. 이처럼 latent code 확률 밀도 함수에 적용된 특징들의 확률 값입니다. GAN 모델의 목표 실제 데이터의 분포에 근사하는 모델을 만드는 것으로 실제 데이터에 대한 확률분포(확률 밀도 함수)와 비슷하게 따라가도록 학습합니다. 그래서 파란선의 실제 데이터의 확률 분포와 빨간 선의 모델에서 나온 확률 분포의 차이가 적도록 만들어야 합니다. 

 

 

GAN?

GAN은 색이 없는 사진에 색을 더해주거나 진짜 같은 가짜 데이터를 만들어내는 방법으로 실제 사진을 변형해서 만들어내는 방법입니다. 그림과 같이 실제로 있을만한 사람의 얼굴 사진을 만들어내거나 동물의 사진을 만들어낼 수 있습니다.

GAN 활용사례 [1]
GAN 활용사례 [2]

 

가짜 데이터에 대한 정답은 주어지지 않는데 어떻게 하면 진짜 같은 데이터를 만들어낼 수 있을까요? 어떻게 그럴듯한 가짜 이미지다라고 정의할 수 있을까요? GAN은 가짜 데이터를 실제 데이터처럼 만들어내는 Generator, 가짜 데이터와 실제 데이터를 판별하는 Discriminator 두 부분으로 나눠져 있습니다. 주요 개념은 Generator과 Discriminator이 서로 대립(Adversarial)하여 성능을 개선해나가는 방식이라는 것입니다.

 

 

GAN 모델 학습 방법 

 

Generator, Discriminator 두 부분 중 어떤 부분이 먼저 학습되어야 할까요? 무엇이 먼저 시작하듯 학습만 잘되면 되기 때문에 둘 중 어떤 부분을 먼저 하든 상관이 없습니다. 처음엔 둘 다 랜덤 값을 초기값으로 사용하기 때문에 엉망입니다. 그럼 어떻게 성능을 높여나갈까요? cost function을 이용하여 조절합니다.

 

 

 

Generator은 랜덤 하게 시도해서 가짜 데이터를 실제 데이터처럼 잘 만들어내도록 loss가 최소가 되도록 만들어야 하고 Discriminator은 가짜 데이터와 실제 데이터를 잘 구분할 수 있도록 loss가 최대가 되도록 만들어야 합니다. 즉, V가 최소화되도록 G를 조절하고 V가 최대화되도록 D를 조절하는 것입니다. 

 

적당한 크기의 random vector를 입력 데이터로 사용하여 Generator에서 가짜 데이터를 만들어냅니다. 이때 가짜 데이터의 shape과 실제 데이터의 shape이 꼭 맞지 않아도 상관없습니다. 가짜 데이터에서 나온 값이 판별 기를 거쳐 1에 가까운 값이 나오도록 합니다. 

 

Discriminator은 실제 데이터와 가짜 데이터를 구분합니다. 출력 데이터는 1차원으로 sigmoid 함수를 거쳐서  0.5를 기준으로 classification을 진행합니다. 가짜 데이터이면 0에 가까운 값을, 실제 데이터이면 1에 가까운 값이 나옵니다. GAN은 진짜 이미지와 같은 가짜 이미지를 만들어내야 되기 때문에 1에 가깝게 나오는 것이 좋습니다. 이 과정이 1번의 epoch입니다. epoch를 반복하며 그럴듯한 가짜 데이터를 만들어내는 것입니다. 

 

cost function수식에서 Discriminator 관점에서 보면 진짜 이미지가 나오면 1에 가깝도록 가짜 이미지가 나오면 0에 가깝도록 학습을 합니다. Generative 관점에서 보면 왼쪽 수식에는 관여할 수 없기 때문에 실제 이미지에 들어가는 값이 1일 때 최소가 됩니다. Generative는 가짜 이미지를 받았을 때 최대 1에 가까운 값을 내놓도록 학습을 시작합니다. 

 

 

GAN 코드

GAN 코드를 통해 학습 과정을 살펴보면 좀 더 이해가 잘 될 것입니다. MNIST 데이터를 이용하여 그럴듯한 손글씨를 만들어보겠습니다.

 

import tensorflow as tf
# tensorflow2 버전에서 1버전 코드 사용하기 
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

tf.reset_default_graph()

import matplotlib.pyplot as plt
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

# 데이터 불러오기 
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)

import os
currentPath = os.getcwd()
print(currentPath)
import os
currentPath = os.getcwd()
print(currentPath)

# parameter
total_epoch = 100
batch_size = 100
n_hidden = 256
n_input = 28 * 28
n_noise = 128
n_class = 10  # 1~9 까지 총 10개 숫자 
X = tf.placeholder(tf.float32, [None, n_input])

# 생성자가 label을 갖고 있다면 더 좋은 성능을 가질 수 있지 않을까?
Y = tf.placeholder(tf.float32, [None, n_class])  
Z = tf.placeholder(tf.float32, [None, n_noise])

print("X 차원: {}".format(X.shape))
print("Y 차원: {}".format(Y.shape))
print("Z 차원: {}".format(Z.shape))
# 가짜 이미지 생성기 
def generator(noise, labels):
    with tf.variable_scope('generator'):
        inputs = tf.concat([noise, labels], 1)  # 입력 데이터에 label을 추가해서 높은 성능을 기대하자. 
        hidden = tf.layers.dense(inputs, n_hidden, activation=tf.nn.relu)
        output = tf.layers.dense(hidden, n_input, activation=tf.nn.sigmoid)
    return output

#분류 생성기 
def discriminator(inputs, labels, reuse=None):
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()   # 변수를 재사용하며 학습 
        inputs = tf.concat([inputs, labels], 1)
        hidden = tf.layers.dense(inputs, n_hidden, activation=tf.nn.relu)
        output = tf.layers.dense(hidden, 1, activation=None)
    return output

# 노이즈 생성
def get_noise(batch_size, n_noise):  # batch size 만큼 균등 분폴르 따르는 노이즈 생성 
    return np.random.uniform(-1., 1., size=[batch_size, n_noise])
G = generator(Z, Y)  # 그럴듯한 가짜 이미지 생성 
D_real = discriminator(X, Y)   # 진짜 이미지 판별
D_gene = discriminator(G, Y, True)  # 가짜 이미지 판별 
# tf.ones_like = D_real과 같은 차원의 행렬 데이터를 만들어냄. 
# tf.zeros_like = D_real과 같은 차원의 행렬 데이털르 만드는데 데이터 내 값은 모두 0
loss_D_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real, labels=tf.ones_like(D_real)))
loss_D_gene = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_gene, labels=tf.zeros_like(D_gene)))

loss_D = loss_D_real + loss_D_gene
loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_gene, labels=tf.ones_like(D_gene)))

vars_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
vars_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')

# loss를 구하는 부분은 NN이 독립적으로 다른 NN이므로 구분 
train_D = tf.train.AdamOptimizer().minimize(loss_D, var_list=vars_D)
train_G = tf.train.AdamOptimizer().minimize(loss_G, var_list=vars_G)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
total_batch = int(mnist.train.num_examples/batch_size)
loss_val_D, loss_val_G = 0, 0

for epoch in range(total_epoch):
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        noise = get_noise(batch_size, n_noise)

        _, loss_val_D = sess.run([train_D, loss_D], feed_dict={X: batch_xs, Y: batch_ys, Z: noise})
        _, loss_val_G = sess.run([train_G, loss_G], feed_dict={Y: batch_ys, Z: noise})

    print('Epoch:', '%04d' % epoch, 'D loss:  {:.4}'.format(loss_val_D), 'G loss: {:.4}'.format(loss_val_G))

    if epoch == 0 or (epoch + 1) % 10 == 0:
        sample_size = 10
        noise = get_noise(sample_size, n_noise)
        samples = sess.run(G, feed_dict={Y: mnist.test.labels[:sample_size], Z: noise})
        
        fig, ax = plt.subplots(2, sample_size, figsize=(sample_size, 2))
        
        for i in range(sample_size):
            ax[0][i].set_axis_off()
            ax[1][i].set_axis_off()
            ax[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
            ax[1][i].imshow(np.reshape(samples[i], (28, 28)))
        plt.savefig('samples2/{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
        plt.close(fig)
        print('Optimization Completed!')

 

참고

[1] [4차 산업 생생 용어] 2개의 인공지능이 경쟁하며 발전한다...'GAN'

[2] What is GAN, tha AI technique that makes computers creative?

[3] GAN의 기초(GAN, DCGAN, WGAN, CGAN)

728x90