
[Paper Code Implementation] Global and Local Image Completion

by 달죽 2020. 11. 23.

2020/11/03 - [Deep Learning/Paper Review] - [Image Inpainting] Globally and Locally Consistent Image Completion
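This post is a TensorFlow 2 / Keras implementation of the model reviewed in the post linked above: a fully convolutional completion network (the generator) trained jointly with a discriminator that scores both the completed local patch and the full global image.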

import sys, os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]='0'

import tensorflow as tf
import numpy as np
import glob
import pathlib
import matplotlib.pyplot as plt
import PIL

# Allow GPU memory to grow on demand instead of grabbing it all up front
# (TF1-style session config; tf.config.experimental.set_memory_growth is
# the pure-TF2 equivalent)
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)

print('''+++ env info +++
Python version : {},
Tensorflow version : {},
Keras version : {}
'''.format(sys.version, tf.__version__, tf.keras.__version__))

print('TF GPU available test :', tf.config.list_physical_devices('GPU'))
#print(tf.test.is_gpu_available(cuda_only=False, min_cuda_compute_capability=None))
# os.environ["CUDA_VISIBLE_DEVICES"]="1"

from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import LeakyReLU, ReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import losses
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing import image
import tensorflow.keras.backend as K
from tensorflow.keras.utils import plot_model

def block_patch(input_, mode, patch_size=50, margin=10):

    shape = input_.get_shape().as_list()

    if mode == 'central_box':
        # zero out a patch at a fixed position (offset = shape/4, so a
        # half-size patch lands in the center of the image)
        pad_size = tf.constant([patch_size, patch_size], dtype=tf.int32)
        patch = tf.zeros([shape[0], pad_size[0], pad_size[1], shape[-1]], dtype=tf.float32)

        h_ = tf.constant([int(shape[1]/4)], dtype=tf.int32)[0]
        w_ = tf.constant([int(shape[2]/4)], dtype=tf.int32)[0]
        coord = h_, w_
        padding = [[0, 0], [h_, shape[1]-h_-pad_size[0]], [w_, shape[2]-w_-pad_size[1]], [0, 0]]
        padded = tf.pad(patch, padding, "CONSTANT", constant_values=1)

        result = tf.multiply(input_, padded)

    elif mode == 'random_box':
        # zero out a patch at a random position, keeping a margin from the border
        pad_size = tf.constant([patch_size, patch_size], dtype=tf.int32)
        patch = tf.zeros([shape[0], pad_size[0], pad_size[1], shape[-1]], dtype=tf.float32)

        h_ = tf.random.uniform([1], minval=margin, maxval=shape[1]-pad_size[0]-margin, dtype=tf.int32)[0]
        w_ = tf.random.uniform([1], minval=margin, maxval=shape[2]-pad_size[1]-margin, dtype=tf.int32)[0]
        coord = h_, w_
        padding = [[0, 0], [h_, shape[1]-h_-pad_size[0]], [w_, shape[2]-w_-pad_size[1]], [0, 0]]
        padded = tf.pad(patch, padding, "CONSTANT", constant_values=1)

        result = tf.multiply(input_, padded)

    else:
        raise ValueError("mode must be 'central_box' or 'random_box'")

    # result: masked images; padded: binary mask (0 inside the hole, 1 outside);
    # coord: top-left corner of the hole; pad_size: hole height/width
    return result, padded, coord, pad_size
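On a toy batch, the returned mask is 0 inside the hole and 1 everywhere else, so the elementwise multiply blanks out the patch. A minimal check on a hypothetical toy tensor (not part of the training script):

toy = tf.ones([1, 8, 8, 3])
_, toy_mask, toy_coord, toy_size = block_patch(toy, mode='central_box', patch_size=4)
print(toy_mask[0, :, :, 0].numpy())  # a 4x4 block of zeros surrounded by ones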
	
def configure_for_performance(ds, batch_size, buffer_size):
    ds = ds.cache()
    ds = ds.shuffle(buffer_size=1024)
    # drop_remainder=True keeps every batch at exactly batch_size, matching
    # the fixed-size real/fake label arrays used in the training loop
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(buffer_size=buffer_size)
    return ds
	
def process_path(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    # scale to [-1, 1] to match the generator's tanh output
    img = tf.image.convert_image_dtype(img, tf.float32) * 2 - 1
    # img_h / img_w are globals defined further down, before the dataset map runs
    img = tf.image.resize(img, [img_h, img_w])
    return img
	
def sample_images(epoch, imgs, save_path):
    
    masked_imgs, mask, coord, pad_size = block_patch(imgs, mode='random_box', patch_size=patch_size, margin=margin)
    y1, y2, x1, x2 = coord[0].numpy(), coord[0].numpy()+pad_size[0].numpy(), coord[1].numpy(), coord[1].numpy()+pad_size[1].numpy()
    inpainted_imgs = G.predict(masked_imgs)
    inpainted_fields = inpaint_crop(inpainted_imgs, coord, pad_size)
    
    imgs = 0.5 * imgs + 0.5
    masked_imgs = 0.5 * masked_imgs + 0.5
    inpainted_imgs = 0.5 * inpainted_imgs + 0.5
    inpainted_fields = 0.5 * inpainted_fields + 0.5

    r, c = 4, 6
    fig, axs = plt.subplots(r, c)
    for i in range(c):
        axs[0,i].imshow(imgs[i])
        axs[0,i].axis('off')
        axs[1,i].imshow(masked_imgs[i])
        axs[1,i].axis('off')
        axs[2,i].imshow(inpainted_imgs[i])
        axs[2,i].axis('off')
        filled_in = tf.identity(masked_imgs)[i].numpy()
        filled_in[y1:y2, x1:x2, :] = inpainted_fields[i].numpy()
        axs[3,i].imshow(filled_in)
        axs[3,i].axis('off')
        
    os.makedirs('images/{}'.format(save_path), exist_ok=True)
    
    fig.savefig("images/{}/{}.png".format(save_path,epoch))
    plt.close()
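In the sample grid, the four rows are: the original images, the masked inputs, the generator's raw output, and the masked input with only the completed hole region pasted back in.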

def inpaint_crop(img, coord, pad_size):
    # extract the completed hole region (the input to the local discriminator)
    crop_box = tf.image.crop_to_bounding_box(img, coord[0], coord[1], pad_size[0], pad_size[1])
    # the resize is a no-op when pad_size already equals (inpainted_h, inpainted_w)
    crop_bbox = tf.image.resize(crop_box, (inpainted_h, inpainted_w))
    return crop_bbox
	
local_shape, global_shape = (64,64,3), (128,128,3)

img_h, img_w, channel = 128, 128, 3
inpainted_h, inpainted_w = 64, 64

batch_size = 16
AUTOTUNE = tf.data.experimental.AUTOTUNE
buffer_size = AUTOTUNE
epochs = 60000
sample_interval = 10

margin = 5
patch_size = 64
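As a quick shape check (a sketch on a dummy tensor; with these settings the resize inside inpaint_crop is an identity, since patch_size matches the inpainted size):

print(inpaint_crop(tf.zeros([1, img_h, img_w, channel]), (32, 32), (64, 64)).shape)  # (1, 64, 64, 3)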

#data_dir = '/home/user_5/bg/python/GDL_code/data/celeb/img_align_celeba/img_align_celeba'
data_dir = '/home/user_5/bg/datasets/sample'
save_path = 'GL_celeb_7'

# Load data
image_count = len(list(glob.glob(data_dir+'/*.jpg')))
list_ds = tf.data.Dataset.list_files(str(data_dir+"/*"), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
list_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)

val_size = int(image_count * 0.01)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)

train_len = tf.data.experimental.cardinality(train_ds).numpy()
val_len = tf.data.experimental.cardinality(val_ds).numpy()
print(train_len)
print(val_len)

train_ds = configure_for_performance(train_ds, batch_size, buffer_size)
val_ds = configure_for_performance(val_ds, batch_size, buffer_size)

sample = next(iter(train_ds))
	
result, padded, coord, pad_size = block_patch(sample, mode='random_box', patch_size=patch_size, margin=margin)
											  
plt.figure(figsize=(10, 5))
for i in range(6):
    ax = plt.subplot(2,3,i+1)
    plt.imshow(result[i].numpy()*0.5+0.5)
    plt.axis("off")
#plt.close()

print(result.shape, padded.shape, coord, pad_size)

def build_generator(global_shape):
    
    inputs = tf.keras.Input(shape=global_shape) # (None, 128, 128, 3)
    
    # Encoder
    x = Conv2D(64, kernel_size=5, strides=1, padding='same', activation='relu')(inputs) # (None, 128, 128, 64)
    
    x = Conv2D(128, kernel_size=3, strides=2, padding='same', activation='relu')(x)
    x = Conv2D(128, kernel_size=3, strides=1, padding='same', activation='relu')(x)
    
    x = Conv2D(256, kernel_size=3, strides=2, padding='same', activation='relu')(x)
    x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
    x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
    
    # Dilated convolutions: enlarge the receptive field without downsampling
    x = Conv2D(256, kernel_size=3, strides=1, dilation_rate=2, padding='same', activation='relu')(x)
    x = Conv2D(256, kernel_size=3, strides=1, dilation_rate=4, padding='same', activation='relu')(x)
    x = Conv2D(256, kernel_size=3, strides=1, dilation_rate=8, padding='same', activation='relu')(x)
    x = Conv2D(256, kernel_size=3, strides=1, dilation_rate=16, padding='same', activation='relu')(x)

    x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
    x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
    
    # Decoder
    x = Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu')(x)
    x = Conv2D(128, kernel_size=3, strides=1, padding='same', activation='relu')(x)

    x = Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', activation='relu')(x)
    x = Conv2D(32, kernel_size=3, strides=1, padding='same', activation='relu')(x)
    x = Conv2D(channel, kernel_size=3, strides=1, padding='same', activation='tanh')(x)    

    model = tf.keras.Model(inputs=inputs, outputs=x, name='G')

    return model
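A quick sanity check (a throwaway instance, just a sketch): the completion network is fully convolutional, so it returns an image of the same size as its input.

G_check = build_generator(global_shape)
print(G_check.output_shape)  # (None, 128, 128, 3)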
	
def build_discriminator(local_shape, global_shape):

    local_inputs = tf.keras.Input(shape=local_shape, name='local_input') # (None, 64, 64, 3)
    global_inputs = tf.keras.Input(shape=global_shape, name='global_input') # (None, 128, 128, 3)
    
    local = Conv2D(64, kernel_size=5, strides=2, padding='same', activation='relu')(local_inputs) 
    local = Conv2D(128, kernel_size=5, strides=2, padding='same', activation='relu')(local) 
    local = Conv2D(256, kernel_size=5, strides=2, padding='same', activation='relu')(local) 
    local = Conv2D(512, kernel_size=5, strides=2, padding='same', activation='relu')(local) 
    local = Conv2D(512, kernel_size=5, strides=2, padding='same', activation='relu')(local) 
    local = Flatten()(local) 
    local = Dense(1024)(local)

    global_ = Conv2D(64, kernel_size=5, strides=2, padding='same', activation='relu')(global_inputs) 
    global_ = Conv2D(128, kernel_size=5, strides=2, padding='same', activation='relu')(global_) 
    global_ = Conv2D(256, kernel_size=5, strides=2, padding='same', activation='relu')(global_) 
    global_ = Conv2D(512, kernel_size=5, strides=2, padding='same', activation='relu')(global_) 
    global_ = Conv2D(512, kernel_size=5, strides=2, padding='same', activation='relu')(global_) 
    global_ = Conv2D(512, kernel_size=5, strides=2, padding='same', activation='relu')(global_) 
    global_ = Flatten()(global_) 
    global_ = Dense(1024)(global_)

    concat = tf.keras.layers.concatenate([local, global_])
    validity = Dense(1, activation='sigmoid')(concat)
    model = tf.keras.Model(inputs=[local_inputs, global_inputs], outputs=validity, name='D')

    return model
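The same kind of check for the two-branch discriminator, which judges the 64x64 completed patch and the full 128x128 image jointly:

D_check = build_discriminator(local_shape, global_shape)
print(D_check.input_shape)   # [(None, 64, 64, 3), (None, 128, 128, 3)]
print(D_check.output_shape)  # (None, 1)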
	
D = build_discriminator(local_shape, global_shape)
D.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])


G = build_generator(global_shape)
G_input = Input(global_shape)
inpainted_img = G(G_input)
# NOTE: coord and pad_size here are the eager tensors captured from the sample
# batch above, so this crop location is frozen into the GL graph. The training
# loop draws a new random hole every batch, so the cropped "local" region will
# not track those holes exactly; feeding the hole coordinates to the model as
# an extra input would let the crop follow the mask.
inpainted_field = inpaint_crop(inpainted_img, coord, pad_size)
validity = D([inpainted_field, inpainted_img])


# Freeze D inside the combined model so that GL updates only the generator
D.trainable = False
GL = tf.keras.Model(G_input, [inpainted_field, validity], name='GL')
GL.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[0.9996, 0.0004], optimizer=Adam())

GL.summary()
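The loss_weights mirror the paper's joint objective: the reconstruction MSE on the hole region dominates, and the adversarial term is down-weighted by alpha = 0.0004, so the two weights here are 1 - alpha and alpha.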

D_losses = []
D_acc = []
G_losses = []
G_mse = []

# Adversarial ground truths
real = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
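These label arrays are created once with a fixed batch_size, which is why the dataset is batched with drop_remainder=True above: every batch then matches their shape.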

for epoch in range(epochs):
    
    for train_batch in train_ds:

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Select a random batch of images

        masked_imgs, mask, coord, pad_size = block_patch(train_batch, mode='random_box', patch_size=patch_size, margin=margin)
        y1, y2, x1, x2 = coord[0].numpy(), coord[0].numpy()+pad_size[0].numpy(), coord[1].numpy(), coord[1].numpy()+pad_size[1].numpy()
        
        # Generate a batch of new images
        inpainted_imgs = G.predict(masked_imgs)
        inpainted_fields = inpaint_crop(inpainted_imgs, coord, pad_size)
        real_fields = inpaint_crop(train_batch, coord, pad_size)
        
        filled_in = tf.identity(masked_imgs).numpy() # writable numpy copy of the masked batch
        filled_in[:, y1:y2, x1:x2, :] = inpainted_fields.numpy()
        filled_in = tf.constant(filled_in)
        
        # Train the discriminator
        D.trainable = True
        d_loss_real = D.train_on_batch([real_fields, train_batch], real)
        d_loss_fake = D.train_on_batch([inpainted_fields, filled_in], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


        # ---------------------
        #  Train Generator
        # ---------------------
        D.trainable = False
        g_loss = GL.train_on_batch(masked_imgs, [real_fields, real])

        # Plot the progress
        D_losses.append(d_loss[0])
        D_acc.append(d_loss[1])
        G_losses.append(g_loss[0])
        G_mse.append(g_loss[1])
        # print('training...')
        # Save generated image samples every sample_interval epochs
    if epoch % sample_interval == 0:
        val_batch = next(iter(val_ds))
        print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))
        val = val_batch[:6]
        sample_images(epoch, val, save_path)
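The loss histories collected in D_losses and G_losses are never visualized in the loop; a minimal sketch to plot them after training (loss_curves.png is an arbitrary filename; the images/<save_path> folder is created by sample_images on the first sampling epoch):

plt.figure(figsize=(8, 4))
plt.plot(D_losses, label='D loss')
plt.plot(G_losses, label='G loss')
plt.xlabel('iteration')
plt.legend()
plt.savefig('images/{}/loss_curves.png'.format(save_path))
plt.close()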