个人视频网站注册平台,唐山市住房和城乡建设局网站,wordpress 自己做主题,网站访问量统计工具简介#xff1a;GAN生成对抗网络本质上是一种思想#xff0c;其依靠神经网络能够拟合任意函数的能力#xff0c;设计了一种架构来实现数据的生成。 原理#xff1a;GAN的原理就是最小化生成器Generator的损失#xff0c;但是在最小化损失的过程中加入了一个约束#xff0… 简介GAN生成对抗网络本质上是一种思想其依靠神经网络能够拟合任意函数的能力设计了一种架构来实现数据的生成。 原理GAN的原理就是最小化生成器Generator的损失但是在最小化损失的过程中加入了一个约束这个约束就是使Generator生成的数据满足我们指定数据的分布GAN的巧妙之处在于使用一个神经网络(鉴别器Discriminator)来自动判断生成的数据是否符合我们所需要的分布。 实现细节 一 准备好我们想要让生成器生成的数据类型比如MINIST手写数字集包含1-10十个数字一共60000张图片。生成器的目的就是学习这个数据集的分布。 二 定义一个生成器用于判别一张图片是实际的还是生成器生成的当生成器完美学习得到数据分布之后鉴别器可能就分不清图片是生成器的还是实际的这样的话生成器就能生成我们想要的图片了。 生成器的训练过程为实际数据输出结果1生成数据输出结果为0目的是学会区分真假数据相当于提供一个约束使生成数据符合指定分布。当鉴别生成器的数据分布时只需要更新鉴别器的参数权重不能够通过计算图将生成器的参数进行更新。 三 定义一个生成器给定一个输入他就能生成1-10里面的一个数字的图片。生成器的反向更新是根据鉴别器的损失来确定(被约束进行反向更新)。生成器的网络权重参数是单独的反向更新时只需要更新计算图当中属于生成器部分的参数。 下面给出生成1-0-1-0数据格式的代码 # %%
import torch
import numpy
import torch.nn as nn
import matplotlib.pyplot as plt# %%
def gennerate1010():return torch.FloatTensor([numpy.random.uniform(0.9,1.1),numpy.random.uniform(0.,.1),numpy.random.uniform(0.9,1.1),numpy.random.uniform(0.0,.1)])# %%
def genneratexxxx():return torch.rand(4)# %%
class Discrimer(nn.Module):def __init__(self) - None:father_obj super(Discrimer,self)father_obj.__init__()self.create_model()self.counter 0self.progress []def create_model(self):self.model nn.Sequential(nn.Linear(4,3),nn.Sigmoid(),nn.Linear(3,1),nn.Sigmoid(), )self.loss_functon nn.MSELoss()self.optimiser torch.optim.SGD(self.parameters(),lr0.01)def forward(self,x):return self.model(x)def train(self,x,targets):outputs self.forward(x)loss self.loss_functon(outputs,targets)self.counter 1if self.counter%10 0:self.progress.append(loss.item())if self.counter%10000 0:print(self.counter)self.optimiser.zero_grad()loss.backward()self.optimiser.step()def plotprogress(self):plt.plot(self.progress,marker*)plt.show()# %%
class Gennerater(nn.Module):def __init__(self) - None:father_obj super(Gennerater,self)father_obj.__init__()self.create_model()self.counter 0self.progress []def create_model(self):self.model nn.Sequential(nn.Linear(1,3),nn.Sigmoid(),nn.Linear(3,4),nn.Sigmoid(), )# 这个优化器只能优化生成器部分的参数self.optimiser torch.optim.SGD(self.parameters(),lr0.01)def forward(self,x):return self.model(x)def train(self,D,x,targets):g_outputs self.forward(x)d_outputs D.forward(g_outputs)# 使用鉴别器的loss函数但是只更新生成器的参数生成器的参数需要根据鉴别器的约束进行更新loss D.loss_functon(d_outputs,targets)self.counter 1if self.counter%10 0:self.progress.append(loss.item())if self.counter%10000 0:print(self.counter)self.optimiser.zero_grad()loss.backward()self.optimiser.step()def plotprogress(self):plt.plot(self.progress,marker*)plt.show()# %%
D Discrimer()# %%
G Gennerater()# %%
for id in range(15000):# 喂入实际数据给鉴别器D.train(gennerate1010(),torch.FloatTensor([1.]))# 喂入生成的数据使用detach从计算图脱离用于更新鉴别器而生成器得不到更新D.train(G.forward(torch.FloatTensor([0.5]).detach()),torch.FloatTensor([0.0]))G.train(D,torch.FloatTensor([0.5]),torch.FloatTensor([1.]))# %%
D.plotprogress()# %%
G.plotprogress()# %%
G.forward(torch.FloatTensor([0.5])) 参考PyTorch生成对抗网络编程