当前位置: 首页 > news >正文

【深度学习】gan网络原理实现猫狗分类

【深度学习】gan网络原理实现猫狗分类

GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。
在这里插入图片描述

1.下载数据并对数据进行规范

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

下载MNIST数据集,并对数据进行规范化。transforms.Compose 是用于定义一系列数据变换的类,ToTensor() 将图像转换为PyTorch张量,Normalize(0.5, 0.5) 对张量进行归一化。然后,创建一个 DataLoader,它将数据集划分成小批次,使得在训练时更容易处理。

2.生成器的代码

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28*28),nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img

这一部分定义了生成器的神经网络模型。生成器的输入是一个大小为100的随机向量,通过多个线性层和激活函数(ReLU),最后通过 nn.Tanh() 激活函数生成大小为28x28的图像。forward 方法定义了前向传播的过程。

3.判别器的代码

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28*28, 512),nn.LeakyReLU(),nn.Linear(512, 256),nn.LeakyReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(-1, 28*28)x = self.main(x)return x

这一部分定义了判别器的神经网络模型。判别器的输入是28x28大小的图像,通过多个线性层和激活函数(LeakyReLU),最后通过 nn.Sigmoid() 激活函数输出一个0到1之间的值,表示输入图像是真实图像的概率。

4. 定义损失函数和优化函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()

这一部分设置了设备(GPU或CPU)、初始化了生成器和判别器的实例,并定义了优化器(Adam优化器)和损失函数(二分类交叉熵损失)。将生成器和判别器移动到设备上进行加速计算。

5.定义绘图函数

def gen_img_plot(model,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i+1)plt.imshow((prediction[i]+1)/2)plt.axis('off')plt.show()

6. 开始训练,并显示出生成器所产生的图像

test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
for epoch in range(30):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader)for step, (img, _) in enumerate(dataloader):img = img.to(device)               # 获得用于训练的mnist图像size = img.size(0)                 # 获得1批次数据量大小# 随机生成size个100维的向量样本值,也即是噪声,用于输入生成器 生成 和mnist一样的图像数据random_noise = torch.randn(size, 100, device=device)########################### 先训练判别器 #############################dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 真实值的loss,也即是真图片与1标签的损失d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 假的值的loss,也即是生成的图像与0标签的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()########################### 下面再训练生成器 #############################gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()#########################################################################with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss
with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('epoch:', epoch)gen_img_plot(gen, test_input)

1.设置 test_input 作为模型的输入,并初始化用于存储判别器(D)和生成器(G)的损失值的列表。

2.开始 30 轮次的训练循环。在每一轮中:

3.对数据集进行遍历。每次迭代,加载一批图像数据 (img)。

4.将图像数据移动到设备(device)上,并获取批次大小。

5.生成随机噪声,作为输入给生成器。

6.训练判别器(D):

  • 对真实图像计算判别器的损失 (d_real_loss),并反向传播计算梯度。
  • 生成生成器产生的图像,并计算判别器的对这些生成图像的损失 (d_fake_loss),再反向传播计算梯度。
  • 计算总的判别器损失 d_loss,并更新判别器的参数。

7.训练生成器(G):

  • 生成器生成图像,并将其输入到判别器中,计算生成器的损失 (g_loss),并反向传播计算梯度。
  • 更新生成器的参数。

这个过程是 GAN 中交替训练生成器和判别器的典型过程,目的是让生成器生成逼真的图像,同时让判别器能够准确区分真假图像。

相关文章:

【深度学习】gan网络原理实现猫狗分类

【深度学习】gan网络原理实现猫狗分类 GAN的基本思想源自博弈论你的二人零和博弈,由一个生成器和一个判别器构成,通过对抗学习的方式训练,目的是估测数据样本的潜在分布并生成新的数据样本。 1.下载数据并对数据进行规范 transform tran…...

⑨【Stream】Redis流是什么?怎么用?: Stream [使用手册]

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ ⑨Redis Stream基本操作命令汇总 一、Redis流 …...

浙江启用无人机巡山护林模式,火灾扑救效率高

为了保护天然的森林资源,浙江当地林业部门引入了一种创新技术:林业无人机。这些天空中的守护者正在重新定义森林防火和护林工作的方式。 当下正值天气干燥的季节,这些无人机开始了它们的首次大规模任务。它们在指定的林区内自主巡逻&#xff…...

Starrocks异步物化视图的使用以及注意事项

最近在使用starrocks来进行实时数据项目的开发,尝试使用了一下starrocks的异步物化视图。 使用版本: 3.1.2-4f3a2ee 创建三个测试表, 注意只有test_mv_table1为分区表,其他两个都是非分区表: CREATE TABLE test_mv_table1 (periodday DATE NOT NULL CO…...

SpringBoot整合Sharding-Jdbc实现分库分表和分布式全局id

SpringBoot整合Sharding-Jdbc Sharding-Jdbc sharding-jdbc是客户端代理的数据库中间件;它和MyCat最大的不同是sharding-jdbc支持库内分表。 整合 数据库环境 在两台不同的主机上分别都创建了sharding_order数据库,库中都有t_order_1和t_order_2两张…...

「江鸟中原」有关HarmonyOS-ArkTS的Http通信请求

一、Http简介 HTTP(Hypertext Transfer Protocol)是一种用于在Web应用程序之间进行通信的协议,通过运输层的TCP协议建立连接、传输数据。Http通信数据以报文的形式进行传输。Http的一次事务包括一个请求和一个响应。 Http通信是基于客户端-服…...

vuex的使用笔记

1.安装 npm安装 npm install vuexnext --saveyarn安装 yarn add vuexnext --save2.基本结构 import Vuex from vuexconst store createStore({ //状态:相当于vue中的data() state() {return {name: 0,code:"",todos: [{ id: 1…...

汇编:关于栈的知识

1.入栈和出栈指令 2. SS与SP 3. 入栈与出栈 3.1 执行push ax ↑↑ 3.2 执行pop ax ↓↓ 3.3 栈顶超界的问题 4. 寄存器赋值 基于8086CPU编程时,可以将一段内存当作栈来使用。一个栈段最大可以设为64KB(0-FFFFH)。 1.入栈和出栈指令…...

uniapp使用map标签

在UniApp中,可以使用map标签来显示地图,并通过其属性来自定义地图的样式和行为。以下是一些常用的map标签属性: id:用于给地图组件指定一个唯一的标识符,方便在代码中进行引用和操作。 style:用来设置地图…...

MacOS14 Sonoma 安装 Flutter 开发环境

本文针对 小白用户也包括自己,以前都是将这些写入我的有道云笔记。为了让给多人看见或者说自己更好的浏览,先将其记录如下。 朋友介绍一个项目说要开发一款App,最近也是闲着就答应下来。主要功能是通过蓝牙BLE控制设备的一个 Iot边缘设备&…...

【Web】PHP反序列化刷题记录

目录 ①[NISACTF 2022]babyserialize ②[NISACTF 2022]popchains ③[SWPUCTF 2022 新生赛]ez_ez_unserialize ④[GDOUCTF 2023]反方向的钟 再巩固下基础 ①[NISACTF 2022]babyserialize <?php include "waf.php"; class NISA{public $fun"show_me_fla…...

C++标准模板库 STL 简介(standard template library)

在 C 语言中&#xff0c;很多东西都是由我们自己去实现的&#xff0c;例如自定义数组&#xff0c;线程文件操作&#xff0c;排序算法等等&#xff0c;有些复杂的东西实现不好很容易留下不易发现的 bug。而 C为使用者提供了一套标准模板库 STL,其中封装了很多实用的容器&#xf…...

Linux篇:文件系统

一、共识原理&#xff1a; 文件文件内容文件属性 磁盘上存储文件存文件的内容&#xff08;数据块&#xff09;存文件的属性&#xff08;inode&#xff09; Linux的文件在磁盘中存储是将属性和内容分开存储的。 二、硬件简述&#xff1a; 1. 认识硬件 磁盘&#xff1a;唯一的一…...

AI - Crowd Simulation(集群模拟)

类似鱼群&#xff0c;鸟群这种群体运动模拟。 是Microscopic Models 微观模型&#xff0c;定义每一个个体的行为&#xff0c;然后合在一起。 主要是根据一定范围内族群其他对象的运动状态决定自己的运动状态 Cohesion 保证个体不会脱离群体 求物体一定半径范围内的其他临近物…...

<JavaEE> Java中线程有多少种状态(State)?状态之间的关系有什么关系?

目录 一、系统内核中的线程状态 二、Java中的线程状态 一、系统内核中的线程状态 状态说明就绪状态线程已经准备就绪&#xff0c;随时可以接受CPU的调度。阻塞状态线程处于阻塞等待&#xff0c;暂时无法在CPU中执行。 二、Java中的线程状态 相比于系统内核&#xff0c;Java…...

正则表达式 通配符 awk文本处理工具

目录 什么是正则表达式 概念 正则表达式的结构 正则表达式的组成 元字符 元字符点&#xff08;.&#xff09; 代表字符. 点值表示点需要转义 \ r..t 代表r到t之间任意两个字符 过滤出小写 过滤出非小写 space空格 [[:space:]] 表示次数 位置锚定 例&#xff1a…...

三、ts高级笔记,

文章目录 18、d.ts声明文件19、Mixin混入20、Decorator装饰器的使用21、-高级proxy拦截_Reflect元储存22、-高级写法Partial-Pick23、Readonly只读_Record套对象24、高阶写法Infer占位符25、Inter实现提取类型和倒叙递归26、object、Object、{}的区别27、localStorage封装28、协…...

二十一、数组(6)

本章概要 数组排序Arrays.sort的使用并行排序binarySearch二分查找parallelPrefix并行前缀 数组排序 根据对象的实际类型执行比较排序。一种方法是为不同的类型编写对应的排序方法&#xff0c;但是这样的代码不能复用。 编程设计的一个主要目标是“将易变的元素与稳定的元素…...

flask依据现有的库表快速生成flask实体类

flask依据现有的库表快速生成flask实体类 在实际开发过程中&#xff0c;flask的sqlalchemy对应的model类写起来重复性较强&#xff0c;如果表比较多会比较繁琐&#xff0c;这个时候可以使用 flask-sqlacodegen 来快速的生成model程序或者py文件&#xff0c;以下是简单的示例&a…...

.NET6 开发一个检查某些状态持续多长时间的类

📢欢迎点赞 :👍 收藏 ⭐留言 📝 如有错误敬请指正,赐人玫瑰,手留余香!📢本文作者:由webmote 原创📢作者格言:新的征程,我们面对的不仅仅是技术还有人心,人心不可测,海水不可量,唯有技术,才是深沉黑夜中的一座闪烁的灯塔 !序言 在代码的世界里,时常碰撞…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合

强化学习&#xff08;Reinforcement Learning, RL&#xff09;是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程&#xff0c;然后使用强化学习的Actor-Critic机制&#xff08;中文译作“知行互动”机制&#xff09;&#xff0c;逐步迭代求解…...

VB.net复制Ntag213卡写入UID

本示例使用的发卡器&#xff1a;https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中&#xff0c;高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术&#xff0c;实现年省电费15%-60%&#xff0c;且不改动原有装备、安装快捷、…...

Unit 1 深度强化学习简介

Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库&#xff0c;例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体&#xff0c;比如 SnowballFight、Huggy the Do…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论

路径问题的革命性重构&#xff1a;基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中&#xff08;图1&#xff09;&#xff1a; mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...

DingDing机器人群消息推送

文章目录 1 新建机器人2 API文档说明3 代码编写 1 新建机器人 点击群设置 下滑到群管理的机器人&#xff0c;点击进入 添加机器人 选择自定义Webhook服务 点击添加 设置安全设置&#xff0c;详见说明文档 成功后&#xff0c;记录Webhook 2 API文档说明 点击设置说明 查看自…...

Webpack性能优化:构建速度与体积优化策略

一、构建速度优化 1、​​升级Webpack和Node.js​​ ​​优化效果​​&#xff1a;Webpack 4比Webpack 3构建时间降低60%-98%。​​原因​​&#xff1a; V8引擎优化&#xff08;for of替代forEach、Map/Set替代Object&#xff09;。默认使用更快的md4哈希算法。AST直接从Loa…...

基于Springboot+Vue的办公管理系统

角色&#xff1a; 管理员、员工 技术&#xff1a; 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能&#xff1a; 该办公管理系统是一个综合性的企业内部管理平台&#xff0c;旨在提升企业运营效率和员工管理水…...

MacOS下Homebrew国内镜像加速指南(2025最新国内镜像加速)

macos brew国内镜像加速方法 brew install 加速formula.jws.json下载慢加速 &#x1f37a; 最新版brew安装慢到怀疑人生&#xff1f;别怕&#xff0c;教你轻松起飞&#xff01; 最近Homebrew更新至最新版&#xff0c;每次执行 brew 命令时都会自动从官方地址 https://formulae.…...