شبکه GAN مخفف generative adversarial network که در فارسی با عنوان شبکه های مولد تخاصمی نیز شناخته میشوند. این شبکه سال 2014 توسط گودفلو در این مقاله معرفی شد.
شبکه GAN یکی از معماری های قوی یادگیری عمیق است که بسیار پرکاربرد است، قابلیت اصلی این شبکه مانند سایر شبکه های مولد، تولید داده است. بطور مثال ایجاد تصاویر جدید برای دیتابیس تصاویر؛
ساختار شبکه GAN
در تصویر زیر با ساختار این شبکه بیشتر آشنا می شوید:
شبکه GAN از دوبخش مولد(Generator) و تفکیک کننده(Discriminator) تشکیل شده است که هر کدام از این دوبخش یک شبکه عصبی هستند. آموزش آنها با الگوریتم پس انتشار خطا (Backpropagation) و با خطای شبکه تفکیک کننده (تابع هزینه) انجام میگیرد.
بخش مولد به دنبال تولید داده هایی است که بخش تفکیک کننده نتواند آن ها را بعنوان تقلبی شناسایی کند. در مقابل بخش تفکیک کننده به دنبال شناسایی تصاویر اصلی از تصاویر تقلبی است.
تابع هزینه شبکه GAN
تابع هزینه هرکدام از شبکه های مولد و تفکیک کنده به پارامترهای دو شبکه وابسته است. اما در فرایند آموزش هر شبکه فقط میتواند پارامترهای خودش را تنظیم کند؛
فرایند بهینه سازی تابع هزینه در GAN
در شبکه های عصبی معمول فرایند آموزش یک فرایند بهینه سازی (مینیمم سازی) تابع هزینه است.
در شبکه GAN دو تابع هزینه وجود دارد: هزینه مولد، هزینه تفکیک کننده.
هر دو این هزینه ها بر اساس خطای تفکیک کننده تعریف میشوند.
مولد به دنبال ماکسیمم سازی خطای تفکیک کننده و تفکیک کننده به دنبال مینیمم سازی!
بهینه سازی رقابتی (تئوری بازی)
در شبکه GAN جواب بهینه زمانی است که هیچکدام از دو شبکه نتوانند نتیجه خود را بهبود دهند، یعنی زمانی که داده های تقلبی تولید شده توسط شبکه مولد از داده های اصلی غیرقابل شناسایی است و شبکه تفکیک کننده در بهترین حالت میتواند به صورت 50/50 تشخیص دهد!
فرایند کلی آموزش GAN
آموزش شبکه تفکیک کننده
- انتخاب تعدادی از داده های آموزشی اصلی به صورت تصادفی (x)
- ایجاد تعدادی بردار تصادفی نویز و تولید تعدادی نمونه تقلبی (*G(z)=x)
- محاسبه خطای x و *x و استفاده از آن برای آموزش وزن های شبکه تفکیک کننده و مینیمم سازی خطای شبکه تفکیک کننده
نکته: وزن شبکه های مولد ثابت است.
آموزش شبکه مولد
- ایجاد تعدادی بردار تصادفی نویز و تولید تعدادی نمونه تقلبی (*G(z)=x)
- محاسبه خطای *x و استفاده از آن برای آموزش وزن های شبکه مولد و ماکزیمم سازی خطای شبکه تفکیک کننده
نکته: وزن های شبکه تفکیک کننده ثابت است.
نکاتی در مورد آموزش شبکه های GAN
نرمال سازی: تصاویر ورودی به شبکه تفکیک کننده بین 1- و 1 نرمال شوند.
نرمالسازی تکهای (Batch Normalization): خروجی هر لایه شبکه نرمالسازی شود.
آموزش بیشتر تفکیک کننده: پیشآموزش(pretrain) شبکه تفکیک کننده قبل از شبکه مولد. و یا آموزش 5 به 1 تفکیک کننده نسبت به مولد؛
اجتناب از گرادیان های خلوت: استفاده از Maxpooling یا Relu ممکن است سبب این موضوع شود اما LeakyRelu خیر.
اجتناب از گرادیان های بزرگ: محدود کردن گرادیان با توجه به ورودی ها.
یادگیری در شبکه های GAN
در شبکه GAN یادگیری نظارت نشده (unsupervised) است زیرا داده های اصلی ورودی برچسب(لیبل) ندارند.
مقایسه شبکه GAN و Autoencoder
- هر دو شبکه قابلیت تولید داده دارند.
- هر دو شبکه یادگیری نظارت نشده دارند.
- شبکه GAN دو تابع هزینه و Autoencoder یک تابع هزینه دارد.
پیاده سازی شبکه GAN در پایتون
در زیر یک نمونه ساده از پیاده سازی یک شبکه GAN در پایتون برای تولید اعداد دستنویس (دیتاست mnist) آورده شده است. کدها و توضیحات این مطلب از آموزش مهندس قاضیخانی در سایت فرادرس جمعآوری شده است:
[py]
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense,Flatten,Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam
img_rows=28
img_cols=28
channels=1
img_shape = (img_rows,img_cols,channels)
zdim=100
def build_gen(img_shape,zdim):
model = Sequential()
model.add(Dense(128,input_dim=zdim))
model.add(LeakyReLU(alpha=0.01))
model.add(Dense(28*28*1,activation=’tanh’))
model.add(Reshape(img_shape))
return model
def build_dis(img_shape):
model=Sequential()
model.add(Flatten(input_shape=img_shape))
model.add(Dense(128))
model.add(LeakyReLU(alpha=0.01))
model.add(Dense(1,activation=’sigmoid’))
return model
def build_gan(gen,dis):
model = Sequential()
model.add(gen)
model.add(dis)
return model
dis_v = build_dis(img_shape)
dis_v.compile(loss=’binary_crossentropy’,
optimizer=Adam(),
metrics=[‘accuracy’])
gen_v = build_gen(img_shape,zdim)
dis_v.trainable=False
gan_v = build_gan(gen_v,dis_v)
gan_v.compile(loss=’binary_crossentropy’,
optimizer=Adam()
)
losses=[]
accuracies=[]
iteration_checks=[]
def train(iterations,batch_size,interval):
(Xtrain, _),(_, _) = mnist.load_data()
Xtrain = Xtrain/127.5 – 1.0
Xtrain = np.expand_dims(Xtrain,axis=3)
real = np.ones((batch_size,1))
fake = np.zeros((batch_size, 1))
for iteration in range(iterations):
ids = np.random.randint(0,Xtrain.shape[0],batch_size)
imgs = Xtrain[ids]
z=np.random.normal(0,1,(batch_size,100))
gen_imgs = gen_v.predict(z)
dloss_real = dis_v.train_on_batch(imgs,real)
dloss_fake = dis_v.train_on_batch(gen_imgs, fake)
dloss,accuracy = 0.5 * np.add(dloss_real,dloss_fake)
z = np.random.normal(0, 1, (batch_size, 100))
gloss = gan_v.train_on_batch(z,real)
if (iteration+1) % interval == 0:
losses.append((dloss,gloss))
accuracies.append(100.0*accuracy)
iteration_checks.append(iteration+1)
print("%d [D loss: %f , acc: %.2f] [G loss: %f]" %
(iteration+1,dloss,100.0*accuracy,gloss))
show_images(gen_v)
def show_images(gen):
z = np.random.normal(0, 1, (16, 100))
gen_imgs = gen.predict(z)
gen_imgs = 0.5*gen_imgs + 0.5
fig,axs = plt.subplots(4,4,figsize=(4,4),sharey=True,sharex=True)
cnt=0
for i in range(4):
for j in range(4):
axs[i, j].imshow(gen_imgs[cnt,:,:,0],cmap=’gray’)
axs[i, j].axis(‘off’)
cnt+=1
fig.show()
train(5000,128,1000)
[/py]