[Paper Review] Progressive Growing of GANs for Improved Quality, Stability, and Variation
GAN (Generative Adversarial Network)Generator(진짜 같은 이미지 생성하려는 생성 모델)와 Discriminator의 경쟁을 통해 학습(적대적 학습)하여 Generator에서 실제 이미지와 구분이 되지 않는 거짓 이미지를 생성
Minmax problem
Generator | Discriminator |
노이즈(z)를 입력받아 latent code로부터 실제 이미지와 유사한 샘플(ex. 이미지) 생성 | 평가(생성 분포가 훈련 분포와 구별 가능 여부를 판별하는 함수) 수행하기 위해 discriminator network 훈련 |
Generator가 학습되면 폐기되는 adaptive loss 함수(적응 손실 함수) |
⇒ 앞선 GAN의 문제를 해결하기 위해 PGGAN에서는 Generator와 Discriminator를 점진적으로 성장시키면서 학습시킴. 즉, 쉬운 저해상도 이미지를 시작으로 새로운 층을 추가하면서 고해상도 디테일을 생성함.
⇒ 학습 속도 향상 & 고해상도에서 안정성 개선
저해상도 4 x 4 픽셀 이미지를 시작으로 generator(G)와 discriminator(D)가 학습하고, 학습이 진행됨에 따라 G와 D에 layer를 추가해가면서 생성된 이미지의 해상도를 증가시킨다.
N x N: N x N 공간 해상도에서 작동하는 covolution layer
⇒ 이는 고해상도에서 안정적으로 이미지 합성이 가능하고 학습 속도를 높일 수 있다.
1024 x 1024까지 progressive growing을 통해 생성된 이미지
(a) 16 x 16 이미지를 (c) 32 x 32 이미지로 해상도를 늘리는 과정
generator에서 3 x 3 convolution layer 후 feature vector에 대해 pixelwise normalization 적용
wi: 가중치 c: He’s initializer에서의 per-layer normalization 상수
→ N(0, 1) 초기화 후 런타임에서 명시적으로 가중치 스케일링
일반적으로 사용되는 adaptive stochastic gradient descent 방법은 기울기 업데이트를 추정 표준편차로 정규화하여 매개변수의 규모와 무관하게 업데이트가 되어 일부 파라미터의 동적 범위가 다른 파라미터보다 큰 경우 조정하는데 시간이 오래 걸림
equalized learning rate는 모든 가중치에 대해 동적 범위, 즉 학습 속도가 동일함을 보장함
local response normalization
ax,y bx,y : 각 픽셀(x,y)의 원본 및 정규화된 feature vector N: # of feature maps
한 GAN의 결과를 다른 GAN과 비교하기 위해 대규모 이미지 컬렉션에서 지표를 계산하는 자동화된 방법 사용
기존 MS-SSIM 방법 한계: 대규모 mode collapse는 안정적으로 찾지만 색상이나 질감에서의 다양성 손실과 같은 작은 영향에 대처를 못함
⇒ 좋은 generator는 모든 scale에 대해 학습 set과 로컬 이미지 구조가 비슷하다는 multi-scale statistical similarity 사용 제안
https://github.com/tkarras/progressive_growing_of_gans
# Minibatch standard deviation.
def minibatch_stddev_layer(x, group_size=4):
with tf.variable_scope('MinibatchStddev'):
group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size.
s = x.shape # [NCHW] Input shape.
y = tf.reshape(x, [group_size, -1, s[1], s[2], s[3]]) # [GMCHW] Split minibatch into M groups of size G.
y = tf.cast(y, tf.float32) # [GMCHW] Cast to FP32.
y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMCHW] Subtract mean over group.
y = tf.reduce_mean(tf.square(y), axis=0) # [MCHW] Calc variance over group.
y = tf.sqrt(y + 1e-8) # [MCHW] Calc stddev over group.
y = tf.reduce_mean(y, axis=[1,2,3], keepdims=True) # [M111] Take average over fmaps and pixels.
y = tf.cast(y, x.dtype) # [M111] Cast back to original data type.
y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [N1HW] Replicate over group and pixels.
return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap.
# Pixelwise feature vector normalization.
def pixel_norm(x, epsilon=1e-8):
with tf.variable_scope('PixelNorm'):
return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon)
# Generator network used in the paper.
def G_paper(
latents_in, # First input: Latent vectors [minibatch, latent_size].
labels_in, # Second input: Labels [minibatch, label_size].
num_channels = 1, # Number of output color channels. Overridden based on dataset.
resolution = 32, # Output resolution. Overridden based on dataset.
label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
fmap_base = 8192, # Overall multiplier for the number of feature maps.
fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
fmap_max = 512, # Maximum number of feature maps in any layer.
latent_size = None, # Dimensionality of the latent vectors. None = min(fmap_base, fmap_max).
normalize_latents = True, # Normalize latent vectors before feeding them to the network?
use_wscale = True, # Enable equalized learning rate?
use_pixelnorm = True, # Enable pixelwise feature vector normalization?
pixelnorm_epsilon = 1e-8, # Constant epsilon for pixelwise feature vector normalization.
use_leakyrelu = True, # True = leaky ReLU, False = ReLU.
dtype = 'float32', # Data type to use for activations and outputs.
fused_scale = True, # True = use fused upscale2d + conv2d, False = separate upscale2d layers.
structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically.
is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.
**kwargs): # Ignore unrecognized keyword args.
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
def PN(x): return pixel_norm(x, epsilon=pixelnorm_epsilon) if use_pixelnorm else x
if latent_size is None: latent_size = nf(0)
if structure is None: structure = 'linear' if is_template_graph else 'recursive'
act = leaky_relu if use_leakyrelu else tf.nn.relu
latents_in.set_shape([None, latent_size])
labels_in.set_shape([None, label_size])
combo_in = tf.cast(tf.concat([latents_in, labels_in], axis=1), dtype)
lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)
# Building blocks.
def block(x, res): # res = 2..resolution_log2
with tf.variable_scope('%dx%d' % (2**res, 2**res)):
if res == 2: # 4x4
if normalize_latents: x = pixel_norm(x, epsilon=pixelnorm_epsilon)
with tf.variable_scope('Dense'):
x = dense(x, fmaps=nf(res-1)*16, gain=np.sqrt(2)/4, use_wscale=use_wscale) # override gain to match the original Theano implementation
x = tf.reshape(x, [-1, nf(res-1), 4, 4])
x = PN(act(apply_bias(x)))
with tf.variable_scope('Conv'):
x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
else: # 8x8 and up
if fused_scale:
with tf.variable_scope('Conv0_up'):
x = PN(act(apply_bias(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
else:
x = upscale2d(x)
with tf.variable_scope('Conv0'):
x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
with tf.variable_scope('Conv1'):
x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
return x
def torgb(x, res): # res = 2..resolution_log2
lod = resolution_log2 - res
with tf.variable_scope('ToRGB_lod%d' % lod):
return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale))
# Linear structure: simple but inefficient.
if structure == 'linear':
x = block(combo_in, 2)
images_out = torgb(x, 2)
for res in range(3, resolution_log2 + 1):
lod = resolution_log2 - res
x = block(x, res)
img = torgb(x, res)
images_out = upscale2d(images_out)
with tf.variable_scope('Grow_lod%d' % lod):
images_out = lerp_clip(img, images_out, lod_in - lod)
# Recursive structure: complex but efficient.
if structure == 'recursive':
def grow(x, res, lod):
y = block(x, res)
img = lambda: upscale2d(torgb(y, res), 2**lod)
if res > 2: img = cset(img, (lod_in > lod), lambda: upscale2d(lerp(torgb(y, res), upscale2d(torgb(x, res - 1)), lod_in - lod), 2**lod))
if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1))
return img()
images_out = grow(combo_in, 2, resolution_log2 - 2)
assert images_out.dtype == tf.as_dtype(dtype)
images_out = tf.identity(images_out, name='images_out')
return
# Discriminator network used in the paper.
def D_paper(
images_in, # Input: Images [minibatch, channel, height, width].
num_channels = 1, # Number of input color channels. Overridden based on dataset.
resolution = 32, # Input resolution. Overridden based on dataset.
label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
fmap_base = 8192, # Overall multiplier for the number of feature maps.
fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
fmap_max = 512, # Maximum number of feature maps in any layer.
use_wscale = True, # Enable equalized learning rate?
mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable.
dtype = 'float32', # Data type to use for activations and outputs.
fused_scale = True, # True = use fused conv2d + downscale2d, False = separate downscale2d layers.
structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically
is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.
**kwargs): # Ignore unrecognized keyword args.
resolution_log2 = int(np.log2(resolution))
assert resolution == 2**resolution_log2 and resolution >= 4
def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
if structure is None: structure = 'linear' if is_template_graph else 'recursive'
act = leaky_relu
images_in.set_shape([None, num_channels, resolution, resolution])
images_in = tf.cast(images_in, dtype)
lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)
# Building blocks.
def fromrgb(x, res): # res = 2..resolution_log2
with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):
return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, use_wscale=use_wscale)))
def block(x, res): # res = 2..resolution_log2
with tf.variable_scope('%dx%d' % (2**res, 2**res)):
if res >= 3: # 8x8 and up
with tf.variable_scope('Conv0'):
x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))
if fused_scale:
with tf.variable_scope('Conv1_down'):
x = act(apply_bias(conv2d_downscale2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale)))
else:
with tf.variable_scope('Conv1'):
x = act(apply_bias(conv2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale)))
x = downscale2d(x)
else: # 4x4
if mbstd_group_size > 1:
x = minibatch_stddev_layer(x, mbstd_group_size)
with tf.variable_scope('Conv'):
x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))
with tf.variable_scope('Dense0'):
x = act(apply_bias(dense(x, fmaps=nf(res-2), use_wscale=use_wscale)))
with tf.variable_scope('Dense1'):
x = apply_bias(dense(x, fmaps=1+label_size, gain=1, use_wscale=use_wscale))
return x
# Linear structure: simple but inefficient.
if structure == 'linear':
img = images_in
x = fromrgb(img, resolution_log2)
for res in range(resolution_log2, 2, -1):
lod = resolution_log2 - res
x = block(x, res)
img = downscale2d(img)
y = fromrgb(img, res - 1)
with tf.variable_scope('Grow_lod%d' % lod):
x = lerp_clip(x, y, lod_in - lod)
combo_out = block(x, 2)
# Recursive structure: complex but efficient.
if structure == 'recursive':
def grow(res, lod):
x = lambda: fromrgb(downscale2d(images_in, 2**lod), res)
if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))
x = block(x(), res); y = lambda: x
if res > 2: y = cset(y, (lod_in > lod), lambda: lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod))
return y()
combo_out = grow(2, resolution_log2 - 2)
assert combo_out.dtype == tf.as_dtype(dtype)
scores_out = tf.identity(combo_out[:, :1], name='scores_out')
labels_out = tf.identity(combo_out[:, 1:], name='labels_out')
return scores_out, labels_out
Sliced-Wasserstein
평가에 사용한 지표: sliced Wasserstein distance (SWD), multi-scale structural similarity (MSSSIM) - A dataset, LSUN BEDROOM
표 1 생성된 이미지와 학습 이미지 사이 Sliced Wasserstein distance (SWD)와 multi-scale structural similarity (MS-SSIM)
실험 결과 생성된 CELEB-A 이미지
⇒ 결론적으로, 출력 간의 변화만 비교하는 MS-SSIM보다 SWD가 생성된 이미지의 분포가 훈련 세트와 더 유사하도록 찾을 수 있음
⇒ 점진적 성장을 사용하지 않을 경우, 생성기와 판별기의 모든 계층은 대규모 변화와 소규모 세부 사항에 대한 간결한 중간 표현을 동시에 찾아야 하는 과제를 안게 되지만, 점진적으로 성장하는 경우 기존의 저해상도 레이어는 이미 초기에 수렴되었을 가능성이 높으므로 네트워크는 새로운 레이어가 도입됨에 따라 점점 더 작은 규모의 효과를 통해 표현을 개선하는 작업만 수행
⇒ progressive growing은 훨씬 더 나은 최적에 수렴하고 총 훈련 시간이 약 2배 단축
네트워크에서 생성된 일부 1024 × 1024 이미지
고품질의 CELEBA 데이터셋 생성
⇒ 해당 데이터셋을 사용하여 높은 출력 해상도를 강력하고 효율적인 방식으로 처리 가능
PGGAN을 사용한 결과와 이전 연구에서의 결과로, 다른 방법을 적용한 결과보더 품질이 더 좋음을 확인할 수 있음
데이터셋
네트워크와 학습 설정
Progression
WGAN-GP’s regularization
γ = 750로 customize (원래는 γ=1.0 즉 1-Lipschiz) → 빠른 transition으로 ghost 최소화 가능
장점
단점