To build intuition, first play with this site: GAN Lab: Play with Generative Adversarial Networks in Your Browser! The essence of a GAN: a dead-leaf butterfly and a bird. The generator's goal: make the butterfly evolve to look ever more like a dead leaf, so the bird cannot reliably spot it. The discriminator's goal: accurately tell whether it is looking at a real dead leaf or a butterfly.
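In the analogy, the bird plays the discriminator D and the butterfly plays the generator G. For reference, the original GAN formulation (Goodfellow et al., 2014) casts this contest as a minimax game over the value function

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

D is rewarded for assigning high probability to real samples x and low probability to generated samples G(z); G is rewarded for fooling D. The discriminator's two BCELoss terms in the code below implement the two D terms directly; the generator step uses the common non-saturating variant, minimizing -log D(G(z)) rather than log(1 - D(G(z))).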
Pseudocode
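Training alternates between the two players. A minimal sketch in Python-style pseudocode (names like D, G, BCE, and sample_noise are placeholders, not the identifiers used in the full code further down):

for x in dataloader:                         # x: one batch of real images
    # 1) Discriminator step: push real images toward 1, generated images toward 0
    z = sample_noise(batch_size)
    loss_d = BCE(D(x), ones) + BCE(D(G(z).detach()), zeros)
    minimize loss_d with respect to D's parameters

    # 2) Generator step: make D assign the label 1 to generated images
    z = sample_noise(batch_size)
    loss_g = BCE(D(G(z)), ones)
    minimize loss_g with respect to G's parameters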
Example

Original data

Example results

Full example code
# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']    # render Chinese labels in plots
plt.rcParams['axes.unicode_minus'] = False

# dir = '... your path/faces/'
dir = './data/train_data'
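# Note (assumption about the dataset layout, not stated in the original):
# tv.datasets.ImageFolder expects at least one class subdirectory under `dir`,
# e.g. ./data/train_data/faces/000001.jpg; image files placed directly in
# ./data/train_data/ will cause ImageFolder to raise an error.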
# path = []
# for fileName in os.listdir(dir):
#     path.append(fileName)    # len(path) = 51223

noiseSize = 100                 # dimension of the input noise vector
n_generator_feature = 64        # number of generator feature maps
n_discriminator_feature = 64    # number of discriminator feature maps
batch_size = 50
d_every = 1                     # train the discriminator once every batch
g_every = 5                     # train the generator once every five batches


class NetGenerator(nn.Module):
    def __init__(self):
        super(NetGenerator, self).__init__()
        self.main = nn.Sequential(  # modules run in the order they are passed to the constructor
            # A transposed convolution enlarges the spatial size of its input (while usually
            # reducing channels); an ordinary convolution shrinks it (while usually adding channels).
            nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(n_generator_feature * 8),
            nn.ReLU(True),      # (n_generator_feature*8) × 4 × 4:  (1-1)*1 - 2*0 + (4-1) + 1 = 4
            nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 4),
            nn.ReLU(True),      # (n_generator_feature*4) × 8 × 8:  (4-1)*2 - 2*1 + (4-1) + 1 = 8
            nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 2),
            nn.ReLU(True),      # (n_generator_feature*2) × 16 × 16
            nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature),
            nn.ReLU(True),      # (n_generator_feature) × 32 × 32
            nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()           # 3 × 96 × 96
        )

    def forward(self, input):
        return self.main(input)


class NetDiscriminator(nn.Module):
    def __init__(self):
        super(NetDiscriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),    # n_discriminator_feature × 32 × 32
            nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),    # (n_discriminator_feature*2) × 16 × 16
            nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),    # (n_discriminator_feature*4) × 8 × 8
            nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),    # (n_discriminator_feature*8) × 4 × 4
            nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()        # output a single probability
        )

    def forward(self, input):
        return self.main(input).view(-1)


def train():
    for i, (image, _) in tqdm.tqdm(enumerate(dataloader)):  # image: one batch of real images, [batch_size, 3, 96, 96]
        real_image = Variable(image)
        # real_image = real_image.cuda()
        if (i + 1) % d_every == 0:      # d_every = 1: train the discriminator on every batch
            optimizer_d.zero_grad()
            output = Discriminator(real_image)      # push real images toward the label True
            error_d_real = criterion(output, true_labels)
            error_d_real.backward()
            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises).detach()   # generate fake images from noise
            fake_output = Discriminator(fake_img)   # push fake images toward the label False
            error_d_fake = criterion(fake_output, fake_labels)
            error_d_fake.backward()
            optimizer_d.step()
        if (i + 1) % g_every == 0:
            optimizer_g.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises)            # no detach() here
            fake_output = Discriminator(fake_img)   # try to make the Discriminator label fakes as True
            error_g = criterion(fake_output, true_labels)
            error_g.backward()
            optimizer_g.step()


def show(num):
    fix_fake_images = Generator(fix_noises)
    fix_fake_images = fix_fake_images.data.cpu()[:64] * 0.5 + 0.5   # undo the [-1, 1] normalization
    # x = torch.rand(64, 3, 96, 96)
    fig = plt.figure(1)
    i = 1
    for image in fix_fake_images:
        ax = fig.add_subplot(8, 8, i)       # split the figure into an 8 × 8 grid and draw at position i
        plt.axis('off')                     # hide the axes
        plt.imshow(image.permute(1, 2, 0))  # permute() reorders dimensions; Matplotlib expects (H, W, C)
        i += 1
    plt.subplots_adjust(left=None,      # the left side of the subplots of the figure
                        right=None,     # the right side of the subplots of the figure
                        bottom=None,    # the bottom of the subplots of the figure
                        top=None,       # the top of the subplots of the figure
                        wspace=0.05,    # the amount of width reserved for blank space between subplots
                        hspace=0.05)    # the amount of height reserved for white space between subplots
    plt.suptitle('Results after iteration %d' % num, y=0.91, fontsize=15)
    plt.savefig('images/%dcgan.png' % num)


if __name__ == '__main__':
    transform = tv.transforms.Compose([
        tv.transforms.Resize(96),       # image size; the transforms.Scale transform is deprecated
        tv.transforms.CenterCrop(96),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))   # map values into [-1, 1]
    ])
    dataset = tv.datasets.ImageFolder(dir, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                             num_workers=4, drop_last=True)  # note the spelling: DataLoader, not DataLoder
    print('Data loaded')
    Generator = NetGenerator()
    Discriminator = NetDiscriminator()
    optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()
    true_labels = Variable(torch.ones(batch_size))      # one label per image in the batch
    fake_labels = Variable(torch.zeros(batch_size))
    fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # fixed noise, reused for the progress plots
    noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))         # standard normal: mean 0, variance 1
    # if torch.cuda.is_available():
    #     print('Cuda is available!')
    #     Generator.cuda()
    #     Discriminator.cuda()
    #     criterion.cuda()
    #     true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
    #     fix_noises, noises = fix_noises.cuda(), noises.cuda()
    # plot_epoch = [1, 5, 10, 50, 100, 200, 500, 800, 1000, 1500, 2000, 2500, 3000]
    plot_epoch = [1, 5, 10, 50, 100, 200, 500, 800, 1000, 1200, 1500]
    for i in range(1500):       # maximum number of iterations (each one is a full pass over the data)
        train()
        print('Iteration {}'.format(i))
        if i in plot_epoch:
            show(i)
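As a quick sanity check of the layer-size comments in the two networks (for ConvTranspose2d with these settings, H_out = (H_in - 1) * stride - 2 * padding + kernel_size), here is a minimal sketch that pushes one small noise batch through freshly built, untrained networks and prints the resulting shapes. It assumes the script above has already been run or imported, so that NetGenerator, NetDiscriminator, and noiseSize are in scope:

import torch

G = NetGenerator()
D = NetDiscriminator()
z = torch.randn(2, noiseSize, 1, 1)     # a batch of two noise vectors
with torch.no_grad():
    img = G(z)
    prob = D(img)
print(img.shape)     # expected: torch.Size([2, 3, 96, 96])
print(prob.shape)    # expected: torch.Size([2]) -- one probability per image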