第G9周:ACGAN理论与实战
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊 | 接辅导、项目定制
- 🚀 文章来源:K同学的学习圈子
上一周已经给出代码,需要可以跳转上一周的任务
第G8周:ACGAN任务
import argparse
import os
import numpy as npimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch# 创建用于存储生成图像的目录
os.makedirs("images", exist_ok=True)# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="训练的总轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每个批次的大小")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的一阶动量衰减")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的二阶动量衰减")
parser.add_argument("--n_cpu", type=int, default=4, help="用于批次生成的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--n_classes", type=int, default=10, help="数据集的类别数")
parser.add_argument("--img_size", type=int, default=32, help="每个图像的尺寸")
parser.add_argument("--channels", type=int, default=1, help="图像通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="图像采样间隔")
opt = parser.parse_args()
print(opt)# 检查是否支持GPU加速
cuda = True if torch.cuda.is_available() else False# 初始化神经网络权重的函数
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)# 生成器网络类
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 为类别标签创建嵌入层self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)# 计算上采样前的初始大小self.init_size = opt.img_size // 4 # Initial size before upsampling# 第一层线性层self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))# 卷积层块self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):# 将标签嵌入到噪声中gen_input = torch.mul(self.label_emb(labels), noise)# 通过第一层线性层out = self.l1(gen_input)# 重新整形为合适的形状out = out.view(out.shape[0], 128, self.init_size, self.init_size)# 通过卷积层块生成图像img = self.conv_blocks(out)return img# 判别器网络类
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器块的函数def discriminator_block(in_filters, out_filters, bn=True):"""返回每个判别器块的层"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return block# 判别器的卷积层块self.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样后图像的高度和宽度ds_size = opt.img_size // 2 ** 4# 输出层self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# 损失函数
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor# 保存生成图像的函数
def sample_image(n_row, batches_done):"""保存从0到n_classes的生成数字的图像网格"""# 采样噪声z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# 为n行生成标签从0到n_classeslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)# ----------
# 训练
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 真实数据的标签valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)# 生成数据的标签fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)# 配置输入real_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------# 训练生成器# -----------------optimizer_G.zero_grad()# 采样噪声和标签作为生成器的输入z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))# 生成一批图像gen_imgs = generator(z, gen_labels)# 损失度量生成器的欺骗判别器的能力validity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 真实图像的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# 生成图像的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 判别器的总损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算判别器的准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)



相关文章:
第G9周:ACGAN理论与实战
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 上一周已经给出代码,需要可以跳转上一周的任务 第G8周:ACGAN任…...
Linux网络部分——DNS域名解析服务
目录 1. 域名结构 2. 系统根据域名查找IP地址的过程 3.DNS域名解析方式 4.DNS域名解析的工作原理【☆】 5.域名解析查询方式 6.搭建主从DNS域名服务器 ①初始化操作主服务器和从服务器,安装BIND软件 ②修改主服务器的主配置文件、区域配置文件、区域数…...
预处理详解
乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 点击主页:optimistic_chen和专栏:c语言, 创作不易,大佬们点赞鼓…...
Python的创建和使用自定义模块
Python 的模块是组织代码的基本单元,它可以包含变量、函数、类等,并且可以被其他 Python 程序引用和重用。除了使用 Python 提供的标准库和第三方库外,开发者还可以创建自定义模块,用于组织和管理自己的代码。本文将详细介绍如何创…...
Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具)
Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具) 场景来源 去年单位内部的一次素拓活动,分工负责策划设置其中的“你画我猜”环节,网络上搜集到题目文字后,想着如何快速做成对应一页一页的PPT。第一时间想…...
小程序地理位置接口权限直接抄作业
小程序地理位置接口有什么功能? 随着小程序生态的发展,越来越多的小程序开发者会通过官方提供的自带接口来给用户提供便捷的服务。但是当涉及到地理位置接口时,却经常遇到申请驳回的问题,反复修改也无法通过,给的理由也…...
【Osek网络管理测试】[TG3_TC6]等待总线睡眠状态_2
🙋♂️ 【Osek网络管理测试】系列💁♂️点击跳转 文章目录 1.环境搭建2.测试目的3.测试步骤4.预期结果5.测试结果 1.环境搭建 硬件:VN1630 软件:CANoe 2.测试目的 验证DUT在满足进入等待睡眠状态的条件时是否进入该状态 …...
BEV下统一的多传感器融合框架 - FUTR3D
BEV下统一的多传感器融合框架 - FUTR3D 引言 在自动驾驶汽车或者移动机器人上,通常会配备许多种传感器,比如:光学相机、激光雷达、毫米波雷达等。由于不同传感器的数据形式不同,如RGB图像,点云等,不同模态…...
c#和python的flask接口的交互
一、灰度图像的传输 c#端的传输 //读入文件夹中的图像 Mat img2 new Mat(file, ImreadModes.AnyColor); //将图像的数据转换成和相机相同的buffer数据 byte[] image_buffer new byte[img2.Width * img2.Height]; int cn img2.Channels(); //通道数 if (cn 1){//将图像的数…...
Python测试框架Pytest的参数化详解
上篇博文介绍过,Pytest是目前比较成熟功能齐全的测试框架,使用率肯定也不断攀升。 在实际工作中,许多测试用例都是类似的重复,一个个写最后代码会显得很冗余。这里,我们来了解一下pytest.mark.parametrize装饰器&…...
KernelSU 如何不通过模块,直接修改系统分区
刚刚看了术哥发的视频,发现kernelSU通过挂载OverlayFS实现无需模块,即可直接修改系统分区,很是方便,并且安全性也很高,于是便有了这篇文章。 下面的教程与原视频存在差异,建议观看原视频后再结合本文章进行操作。 在未进行修改前,我们打开/system/文件夹,并在里面创建…...
红日靶场ATTCK 1通关攻略
环境 拓扑图 VM1 web服务器 win7(192.168.22.129,10.10.10.140) VM2 win2003(10.10.10.135) VM3 DC win2008(10.10.10.138) 环境搭建 win7: 设置内网两张网卡,开启…...
CellMarker | 人骨骼肌组织细胞Marker大全!~(强烈建议火速收藏!)
1写在前面 分享一下最近看到的2篇paper关于骨骼肌组织的细胞Marker,绝对的Atlas级好东西。👍 希望做单细胞的小伙伴觉得有用哦。😏 2常用marker(一) general_mrkrs <- c( MYH7, TNNT1, TNNT3, MYH1, MYH2, "C…...
游戏名台词大赏
文章目录 原神(圈内) 崩坏:星穹铁道(圈内) 崩坏3(圈内) 原神 只要不失去你的崇高,整个世界都会为你敞开。 总会有地上的生灵,敢于直面雷霆的威光。 谁也没有见过风&…...
OpenCV如何在图像中寻找轮廓(60)
返回:OpenCV系列文章目录(持续更新中......) 上一篇:OpenCV如何模板匹配(59) 下一篇 :OpenCV检测凸包(61) 目标 在本教程中,您将学习如何: 使用 OpenCV 函数 cv::findContours使用 OpenCV 函数 cv::d rawContours …...
java 泛型题目讲解
泛型的知识点 泛型仅存在于编译时期,编译期间JAVA将会使用Object类型代替泛型类型,在运行时期不存在泛型;且所有泛型实例共享一个泛型类 public class Main{public static void main(String[] args){ArrayList<String> list1new Arra…...
pptx 文件版面分析-- python-pptx(python 文档解析提取)
安装 pip install python-pptx -i https://pypi.tuna.tsinghua.edu.cn/simple --ignore-installedpptx 解析代码实现 from pptx import Presentation file_name "rag_pptx/test1.pptx" # 打开.pptx文件 ppt Presentation(file_name) for slide in ppt.slides:#pr…...
http的basic 认证方式
写在前面 本文看下http的basic auth认证方式。 1:什么是basic auth认证 basic auth是一种http协议规范中的一种认证方式,即一种证明你就是你的方式。更进一步的它是一种规范,这种规范是这样子,如果是服务端使用了basic auth认证…...
【信息系统项目管理师练习题】信息系统治理
IT治理的核心是关注以下哪项内容? a) 人员培训和发展计划 b) IT定位和信息化建设与数字化转型的责权利划分 c) 业务流程的绩效管理 d) IT基础设施的优化利用 答案: b) IT定位和信息化建设与数字化转型的责权利划分 IT治理体系框架的组成部分包括以下哪些? a) IT战略目标、IT治…...
RabbitMQ之顺序消费
什么是顺序消费 例如:业务上产生者发送三条消息, 分别是对同一条数据的增加、修改、删除操作, 如果没有保证顺序消费,执行顺序可能变成删除、修改、增加,这就乱了。 如何保证顺序性 一般我们讨论如何保证消息的顺序性&…...
(十)学生端搭建
本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...
微服务商城-商品微服务
数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...
在QWebEngineView上实现鼠标、触摸等事件捕获的解决方案
这个问题我看其他博主也写了,要么要会员、要么写的乱七八糟。这里我整理一下,把问题说清楚并且给出代码,拿去用就行,照着葫芦画瓢。 问题 在继承QWebEngineView后,重写mousePressEvent或event函数无法捕获鼠标按下事…...
Java求职者面试指南:计算机基础与源码原理深度解析
Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...
JS手写代码篇----使用Promise封装AJAX请求
15、使用Promise封装AJAX请求 promise就有reject和resolve了,就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...
WEB3全栈开发——面试专业技能点P7前端与链上集成
一、Next.js技术栈 ✅ 概念介绍 Next.js 是一个基于 React 的 服务端渲染(SSR)与静态网站生成(SSG) 框架,由 Vercel 开发。它简化了构建生产级 React 应用的过程,并内置了很多特性: ✅ 文件系…...
【51单片机】4. 模块化编程与LCD1602Debug
1. 什么是模块化编程 传统编程会将所有函数放在main.c中,如果使用的模块多,一个文件内会有很多代码,不利于组织和管理 模块化编程则是将各个模块的代码放在不同的.c文件里,在.h文件里提供外部可调用函数声明,其他.c文…...
echarts使用graphic强行给图增加一个边框(边框根据自己的图形大小设置)- 适用于无法使用dom的样式
pdf-lib https://blog.csdn.net/Shi_haoliu/article/details/148157624?spm1001.2014.3001.5501 为了完成在pdf中导出echarts图,如果边框加在dom上面,pdf-lib导出svg的时候并不会导出边框,所以只能在echarts图上面加边框 grid的边框是在图里…...
初探用uniapp写微信小程序遇到的问题及解决(vue3+ts)
零、关于开发思路 (一)拿到工作任务,先理清楚需求 1.逻辑部分 不放过原型里说的每一句话,有疑惑的部分该问产品/测试/之前的开发就问 2.页面部分(含国际化) 整体看过需要开发页面的原型后,分类一下哪些组件/样式可以复用,直接提取出来使用 (时间充分的前提下,不…...
C#中用于控制自定义特性(Attribute)
我们来详细解释一下 [AttributeUsage(AttributeTargets.Class, AllowMultiple false, Inherited false)] 这个 C# 属性。 在 C# 中,Attribute(特性)是一种用于向程序元素(如类、方法、属性等)添加元数据的机制。Attr…...
