GAN Generative Adversarial Networks - Pix2Pix GAN: Principles and Basic Implementation - Image Translation 09
What is pix2pix GAN?
In an ordinary GAN, the generator G takes a random vector as input and outputs an image, while the discriminator D takes an image (generated or real) as input and outputs real or fake. Trained against each other this way, G and D end up producing realistic images.
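As a point of reference, here is a minimal sketch of that vanilla-GAN interface (toy shapes of my own choosing, not the networks used later in this post): a random vector goes into G and an image comes out; D maps an image to a single real/fake logit.

```python
import tensorflow as tf

# toy generator: 100-dim random vector in, 64x64 RGB image out
G = tf.keras.Sequential([
    tf.keras.Input(shape=(100,)),
    tf.keras.layers.Dense(64 * 64 * 3, activation='tanh'),
    tf.keras.layers.Reshape((64, 64, 3)),
])
# toy discriminator: image in, one real/fake logit out
D = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])

z = tf.random.normal([1, 100])
fake = G(z)        # (1, 64, 64, 3)
score = D(fake)    # (1, 1) -- one real/fake judgment per image
```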
For an image-translation task, G's input should obviously be an image x, and its output another image y. No random input needs to be added.
For image-translation tasks, the input and output share a great deal of information; the contours, for instance, are common to both. With an ordinary convolutional network (a plain encoder-decoder), every layer would have to carry and preserve all of that shared information, which makes the network error-prone.
U-Net is also an encoder-decoder model, just a modified one: the feature map of layer i is concatenated onto layer n-i. This works because layers i and n-i have the same spatial size, so they can be assumed to carry similar information.
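A minimal sketch of one such skip connection, using a toy two-level encoder/decoder of my own (the real generator appears in the full code below): the decoder activation at layer n-i is concatenated with the encoder activation at layer i, which has the same spatial size.

```python
import tensorflow as tf

x_in = tf.keras.layers.Input(shape=[64, 64, 3])
e1 = tf.keras.layers.Conv2D(32, 3, strides=2, padding='same',
                            activation='relu')(x_in)          # layer i:   (32, 32, 32)
e2 = tf.keras.layers.Conv2D(64, 3, strides=2, padding='same',
                            activation='relu')(e1)            # bottleneck: (16, 16, 64)
d1 = tf.keras.layers.Conv2DTranspose(32, 3, strides=2, padding='same',
                                     activation='relu')(e2)   # layer n-i: (32, 32, 32)
d1 = tf.keras.layers.Concatenate()([d1, e1])  # skip: concat layer i onto layer n-i
out = tf.keras.layers.Conv2DTranspose(3, 3, strides=2, padding='same',
                                      activation='tanh')(d1)  # back to (64, 64, 3)
unet_sketch = tf.keras.Model(x_in, out)
```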
The discriminator's input, however, needs to change: besides the generated image having to look realistic, it also has to match the input image. So D's input was modified: D receives pairs of images, much like a conditional GAN (cGAN), as sketched below.
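Concretely, "a pair of images" usually means channel-wise concatenation before the first convolution. A minimal sketch (shapes assumed, matching the 64x64 pipeline used later):

```python
import tensorflow as tf

inp = tf.random.normal([1, 64, 64, 3])   # condition image x (e.g. a label map)
tar = tf.random.normal([1, 64, 64, 3])   # real y, or generated G(x)
pair = tf.concat([inp, tar], axis=-1)    # (1, 64, 64, 6) -- this is what D judges
```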
In the paper, the Pix2Pix discriminator is implemented as a Patch-D (PatchGAN): however large the generated image is, it is judged as a set of fixed-size patches fed into D. The benefit of this design is that D's input is small, the computation is cheap, and training is fast.
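Note that in practice the patch splitting is implicit: D is fully convolutional, so its output is a grid of logits, and each logit's receptive field covers one patch of the input. A minimal sketch of my own (not the exact discriminator from this post, which appears in the full code below):

```python
import tensorflow as tf

def patch_d_sketch():
    img = tf.keras.layers.Input(shape=[64, 64, 3])
    x = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(img)   # (bs, 32, 32, 64)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)    # (bs, 16, 16, 128)
    x = tf.keras.layers.LeakyReLU()(x)
    # no global pooling or dense head: one logit per spatial cell = one logit per patch
    patch_logits = tf.keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)  # (bs, 16, 16, 1)
    return tf.keras.Model(img, patch_logits)
```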
Discriminator loss:
- a real input/target pair should be judged as 1;
- a generated image paired with the input should be judged as 0.
Generator loss:
- a generated image paired with the input should be judged (by D) as 1.
For image-translation tasks, G's input and output really do share a lot of information; in image colorization, for example, they share the edge information. So, to enforce this consistency between input and output, an L1 loss is added on top of the GAN loss.
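In the notation of the pix2pix paper (written here without the noise term z, to match this post's no-random-input setup), the objective being described is:

$$\mathcal{L}_{cGAN}(G,D) = \mathbb{E}_{x,y}\left[\log D(x,y)\right] + \mathbb{E}_{x}\left[\log\left(1 - D(x, G(x))\right)\right]$$

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\left[\lVert y - G(x)\rVert_1\right], \qquad G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G,D) + \lambda\,\mathcal{L}_{L1}(G)$$

where x is the input image, y the target, and λ (LAMBDA in the code below) weights the L1 term.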
In short, Pix2Pix combines:
- cGAN: the input is an image rather than a random vector;
- U-Net: skip connections share more information between encoder and decoder;
- pairs fed into D, to guarantee the input-output mapping;
- Patch-D, to cut computation and improve results;
- an L1 loss, to keep input and output consistent.
(Paper page: https://phillipi.github.io/pix2pix/)
The version used here is a subset of the original Cityscapes dataset. In it, the semantic-segmentation map and the original photo are displayed together in one picture; it is one of the best datasets for semantic-segmentation tasks.
The dataset contains 2975 training images and 500 validation images.
Each file is 256x512 pixels, and each is a composite: the left half of the image is the original photo, and the right half is the labeled image (the semantic-segmentation output).
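A minimal sketch of that left/right split (a random tensor stands in for a decoded JPEG here; the real pipeline below does the same thing on the actual files):

```python
import tensorflow as tf

combined = tf.random.uniform([256, 512, 3])  # stand-in for one decoded dataset image
w = tf.shape(combined)[1] // 2
photo = combined[:, :w, :]                   # left half: original street photo
label = combined[:, w:, :]                   # right half: segmentation map
```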
Code
First, the input pipeline. Each 256x512 file is split into its two halves, resized down to 64x64 (the paper works at 256x256; this demo shrinks the images, presumably to keep training fast), randomly flipped for augmentation, and normalized to [-1, 1]. Note the loader returns the pair as (segmentation map, photo): the label map is the generator's input and the photo is its target.

```python
import tensorflow as tf
import os
import glob
from matplotlib import pyplot as plt
%matplotlib inline
import time
from IPython import display

imgs_path = glob.glob(r'D:\163\gan20\pix2pix\datasets\cityscapes_data\train\*.jpg')

def read_jpg(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    return img

def normalize(input_image, input_mask):
    # scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
    input_image = tf.cast(input_image, tf.float32) / 127.5 - 1
    input_mask = tf.cast(input_mask, tf.float32) / 127.5 - 1
    return input_image, input_mask

def load_image(image_path):
    image = read_jpg(image_path)
    w = tf.shape(image)[1] // 2
    input_image = image[:, :w, :]   # left half: street photo
    input_mask = image[:, w:, :]    # right half: segmentation map
    input_image = tf.image.resize(input_image, (64, 64))
    input_mask = tf.image.resize(input_mask, (64, 64))
    if tf.random.uniform(()) > 0.5:  # random horizontal flip (same flip for both halves)
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_mask, input_image  # (generator input, target)

dataset = tf.data.Dataset.from_tensor_slices(imgs_path)
train = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

BATCH_SIZE = 8
BUFFER_SIZE = 100
train_dataset = train.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# preview one training pair
plt.figure(figsize=(5, 2))
for mask, photo in train_dataset.take(1):
    plt.subplot(1, 2, 1)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(mask[0]))
    plt.subplot(1, 2, 2)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(photo[0]))

imgs_path_test = glob.glob(r'D:\163\gan20\pix2pix\datasets\cityscapes_data\val\*.jpg')
dataset_test = tf.data.Dataset.from_tensor_slices(imgs_path_test)

def load_image_test(image_path):
    # same as load_image, but without the random flip
    image = read_jpg(image_path)
    w = tf.shape(image)[1] // 2
    input_image = image[:, :w, :]
    input_mask = image[:, w:, :]
    input_image = tf.image.resize(input_image, (64, 64))
    input_mask = tf.image.resize(input_mask, (64, 64))
    input_image, input_mask = normalize(input_image, input_mask)
    return input_mask, input_image

dataset_test = dataset_test.map(load_image_test)
dataset_test = dataset_test.batch(BATCH_SIZE)

# preview one validation pair
plt.figure(figsize=(5, 2))
for mask, photo in dataset_test.take(1):
    plt.subplot(1, 2, 1)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(mask[0]))
    plt.subplot(1, 2, 2)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(photo[0]))
```

Next, the U-Net generator: six downsampling blocks shrink the 64x64 input to 1x1, then six transposed-convolution steps bring it back up, with each decoder level concatenated to the encoder activation of the same spatial size.

```python
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_batchnorm=True):
    # Conv(stride 2) -> [BatchNorm] -> LeakyReLU: halves the spatial size
    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2,
                                      padding='same', use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result

def upsample(filters, size, apply_dropout=False):
    # ConvTranspose(stride 2) -> BatchNorm -> [Dropout] -> ReLU: doubles the spatial size
    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                               padding='same', use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    result.add(tf.keras.layers.ReLU())
    return result

def Generator():
    inputs = tf.keras.layers.Input(shape=[64, 64, 3])
    down_stack = [
        downsample(32, 3, apply_batchnorm=False),  # (bs, 32, 32, 32)
        downsample(64, 3),                         # (bs, 16, 16, 64)
        downsample(128, 3),                        # (bs, 8, 8, 128)
        downsample(256, 3),                        # (bs, 4, 4, 256)
        downsample(512, 3),                        # (bs, 2, 2, 512)
        downsample(512, 3),                        # (bs, 1, 1, 512)
    ]
    up_stack = [
        upsample(512, 3, apply_dropout=True),      # (bs, 2, 2, 1024)
        upsample(256, 3, apply_dropout=True),      # (bs, 4, 4, 512)
        upsample(128, 3, apply_dropout=True),      # (bs, 8, 8, 256)
        upsample(64, 3),                           # (bs, 16, 16, 128)
        upsample(32, 3),                           # (bs, 32, 32, 64)
    ]
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 3, strides=2,
                                           padding='same',
                                           activation='tanh')  # (bs, 64, 64, 3)

    # downsampling, recording each encoder activation for the skip connections
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    skips = reversed(skips[:-1])  # the bottleneck itself is not used as a skip

    # upsampling, concatenating layer i onto layer n-i
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

generator = Generator()
# tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)
```

The discriminator takes the (input, target) pair concatenated along the channel axis and outputs a grid of patch logits:

```python
def Discriminator():
    inp = tf.keras.layers.Input(shape=[64, 64, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[64, 64, 3], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar])              # (bs, 64, 64, channels*2)
    down1 = downsample(32, 3, False)(x)                      # (bs, 32, 32, 32)
    down2 = downsample(64, 3)(down1)                         # (bs, 16, 16, 64)
    down3 = downsample(128, 3)(down2)                        # (bs, 8, 8, 128)
    conv = tf.keras.layers.Conv2D(256, 3, strides=1, padding='same',
                                  use_bias=False)(down3)     # (bs, 8, 8, 256)
    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
    # valid padding shrinks 8x8 to 6x6; each logit scores one patch of the pair
    last = tf.keras.layers.Conv2D(1, 3, strides=1)(leaky_relu)  # (bs, 6, 6, 1)
    return tf.keras.Model(inputs=[inp, tar], outputs=last)

discriminator = Discriminator()
# tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)
```

Losses and optimizers. D pushes real pairs toward 1 and generated pairs toward 0; G pushes its pairs toward 1 and additionally minimizes the L1 distance to the target, weighted by LAMBDA:

```python
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
LAMBDA = 10

def generator_loss(disc_generated_output, gen_output, target):
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))  # mean absolute error
    total_gen_loss = gan_loss + LAMBDA * l1_loss
    return total_gen_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_loss + generated_loss

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
```

Finally, a visualization helper and the training loop:

```python
def generate_images(model, test_input, tar):
    # training=True keeps dropout active at inference, as in the paper
    prediction = model(test_input, training=True)
    plt.figure(figsize=(7, 2))
    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        plt.imshow(display_list[i] * 0.5 + 0.5)  # map [-1, 1] back to [0, 1]
        plt.axis('off')
    plt.show()

for example_input, example_target in dataset_test.take(1):
    generate_images(generator, example_input, example_target)

EPOCHS = 110

@tf.function
def train_step(input_image, target, epoch):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
            disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))

def fit(train_ds, epochs, test_ds):
    for epoch in range(epochs + 1):
        if epoch % 10 == 0:
            # sample the generator on one validation batch every 10 epochs
            for example_input, example_target in test_ds.take(1):
                generate_images(generator, example_input, example_target)
            print("Epoch: ", epoch)
        for n, (input_image, target) in train_ds.enumerate():
            if n % 10 == 0:
                print('.', end='')
            train_step(input_image, target, epoch)
        print()

fit(train_dataset, EPOCHS, dataset_test)
```
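After training, only the generator is needed for inference. A possible follow-up (the file name is my own choice, and this builds on the objects defined above):

```python
generator.save('pix2pix_generator.h5')  # save just the generator (functional model)

# later: reload it and translate one validation batch
g = tf.keras.models.load_model('pix2pix_generator.h5')
for mask, photo in dataset_test.take(1):
    generate_images(g, mask, photo)  # dropout stays on at test time, as in the paper
```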