当前位置: 首页 > news >正文

第G9周:ACGAN理论与实战

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

上一周已经给出代码,需要可以跳转上一周的任务
第G8周:ACGAN任务

import argparse
import os
import numpy as npimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch# 创建用于存储生成图像的目录
os.makedirs("images", exist_ok=True)# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="训练的总轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每个批次的大小")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的一阶动量衰减")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的二阶动量衰减")
parser.add_argument("--n_cpu", type=int, default=4, help="用于批次生成的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--n_classes", type=int, default=10, help="数据集的类别数")
parser.add_argument("--img_size", type=int, default=32, help="每个图像的尺寸")
parser.add_argument("--channels", type=int, default=1, help="图像通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="图像采样间隔")
opt = parser.parse_args()
print(opt)# 检查是否支持GPU加速
cuda = True if torch.cuda.is_available() else False# 初始化神经网络权重的函数
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)# 生成器网络类
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 为类别标签创建嵌入层self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)# 计算上采样前的初始大小self.init_size = opt.img_size // 4  # Initial size before upsampling# 第一层线性层self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))# 卷积层块self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):# 将标签嵌入到噪声中gen_input = torch.mul(self.label_emb(labels), noise)# 通过第一层线性层out = self.l1(gen_input)# 重新整形为合适的形状out = out.view(out.shape[0], 128, self.init_size, self.init_size)# 通过卷积层块生成图像img = self.conv_blocks(out)return img# 判别器网络类
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器块的函数def discriminator_block(in_filters, out_filters, bn=True):"""返回每个判别器块的层"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return block# 判别器的卷积层块self.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样后图像的高度和宽度ds_size = opt.img_size // 2 ** 4# 输出层self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# 损失函数
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor# 保存生成图像的函数
def sample_image(n_row, batches_done):"""保存从0到n_classes的生成数字的图像网格"""# 采样噪声z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# 为n行生成标签从0到n_classeslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)# ----------
# 训练
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 真实数据的标签valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)# 生成数据的标签fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)# 配置输入real_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------# 训练生成器# -----------------optimizer_G.zero_grad()# 采样噪声和标签作为生成器的输入z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))# 生成一批图像gen_imgs = generator(z, gen_labels)# 损失度量生成器的欺骗判别器的能力validity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 真实图像的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# 生成图像的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 判别器的总损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算判别器的准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

相关文章:

第G9周:ACGAN理论与实战

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 上一周已经给出代码,需要可以跳转上一周的任务 第G8周:ACGAN任…...

Linux网络部分——DNS域名解析服务

目录 1. 域名结构 2. 系统根据域名查找IP地址的过程 3.DNS域名解析方式 4.DNS域名解析的工作原理【☆】 5.域名解析查询方式 6.搭建主从DNS域名服务器 ①初始化操作主服务器和从服务器,安装BIND软件 ②修改主服务器的主配置文件、区域配置文件、区域数…...

预处理详解

乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 点击主页:optimistic_chen和专栏:c语言, 创作不易,大佬们点赞鼓…...

Python的创建和使用自定义模块

Python 的模块是组织代码的基本单元,它可以包含变量、函数、类等,并且可以被其他 Python 程序引用和重用。除了使用 Python 提供的标准库和第三方库外,开发者还可以创建自定义模块,用于组织和管理自己的代码。本文将详细介绍如何创…...

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具)

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具) 场景来源 去年单位内部的一次素拓活动,分工负责策划设置其中的“你画我猜”环节,网络上搜集到题目文字后,想着如何快速做成对应一页一页的PPT。第一时间想…...

小程序地理位置接口权限直接抄作业

小程序地理位置接口有什么功能? 随着小程序生态的发展,越来越多的小程序开发者会通过官方提供的自带接口来给用户提供便捷的服务。但是当涉及到地理位置接口时,却经常遇到申请驳回的问题,反复修改也无法通过,给的理由也…...

【Osek网络管理测试】[TG3_TC6]等待总线睡眠状态_2

🙋‍♂️ 【Osek网络管理测试】系列💁‍♂️点击跳转 文章目录 1.环境搭建2.测试目的3.测试步骤4.预期结果5.测试结果 1.环境搭建 硬件:VN1630 软件:CANoe 2.测试目的 验证DUT在满足进入等待睡眠状态的条件时是否进入该状态 …...

BEV下统一的多传感器融合框架 - FUTR3D

BEV下统一的多传感器融合框架 - FUTR3D 引言 在自动驾驶汽车或者移动机器人上,通常会配备许多种传感器,比如:光学相机、激光雷达、毫米波雷达等。由于不同传感器的数据形式不同,如RGB图像,点云等,不同模态…...

c#和python的flask接口的交互

一、灰度图像的传输 c#端的传输 //读入文件夹中的图像 Mat img2 new Mat(file, ImreadModes.AnyColor); //将图像的数据转换成和相机相同的buffer数据 byte[] image_buffer new byte[img2.Width * img2.Height]; int cn img2.Channels(); //通道数 if (cn 1){//将图像的数…...

Python测试框架Pytest的参数化详解

上篇博文介绍过,Pytest是目前比较成熟功能齐全的测试框架,使用率肯定也不断攀升。 在实际工作中,许多测试用例都是类似的重复,一个个写最后代码会显得很冗余。这里,我们来了解一下pytest.mark.parametrize装饰器&…...

KernelSU 如何不通过模块,直接修改系统分区

刚刚看了术哥发的视频,发现kernelSU通过挂载OverlayFS实现无需模块,即可直接修改系统分区,很是方便,并且安全性也很高,于是便有了这篇文章。 下面的教程与原视频存在差异,建议观看原视频后再结合本文章进行操作。 在未进行修改前,我们打开/system/文件夹,并在里面创建…...

红日靶场ATTCK 1通关攻略

环境 拓扑图 VM1 web服务器 win7(192.168.22.129,10.10.10.140) VM2 win2003(10.10.10.135) VM3 DC win2008(10.10.10.138) 环境搭建 win7: 设置内网两张网卡,开启…...

CellMarker | 人骨骼肌组织细胞Marker大全!~(强烈建议火速收藏!)

1写在前面 分享一下最近看到的2篇paper关于骨骼肌组织的细胞Marker&#xff0c;绝对的Atlas级好东西。&#x1f44d; 希望做单细胞的小伙伴觉得有用哦。&#x1f60f; 2常用marker&#xff08;一&#xff09; general_mrkrs <- c( MYH7, TNNT1, TNNT3, MYH1, MYH2, "C…...

游戏名台词大赏

文章目录 原神&#xff08;圈内&#xff09; 崩坏&#xff1a;星穹铁道&#xff08;圈内&#xff09; 崩坏3&#xff08;圈内&#xff09; 原神 只要不失去你的崇高&#xff0c;整个世界都会为你敞开。 总会有地上的生灵&#xff0c;敢于直面雷霆的威光。 谁也没有见过风&…...

OpenCV如何在图像中寻找轮廓(60)

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV如何模板匹配(59) 下一篇 :OpenCV检测凸包(61) 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 cv::findContours使用 OpenCV 函数 cv::d rawContours …...

java 泛型题目讲解

泛型的知识点 泛型仅存在于编译时期&#xff0c;编译期间JAVA将会使用Object类型代替泛型类型&#xff0c;在运行时期不存在泛型&#xff1b;且所有泛型实例共享一个泛型类 public class Main{public static void main(String[] args){ArrayList<String> list1new Arra…...

pptx 文件版面分析-- python-pptx(python 文档解析提取)

安装 pip install python-pptx -i https://pypi.tuna.tsinghua.edu.cn/simple --ignore-installedpptx 解析代码实现 from pptx import Presentation file_name "rag_pptx/test1.pptx" # 打开.pptx文件 ppt Presentation(file_name) for slide in ppt.slides:#pr…...

http的basic 认证方式

写在前面 本文看下http的basic auth认证方式。 1&#xff1a;什么是basic auth认证 basic auth是一种http协议规范中的一种认证方式&#xff0c;即一种证明你就是你的方式。更进一步的它是一种规范&#xff0c;这种规范是这样子&#xff0c;如果是服务端使用了basic auth认证…...

【信息系统项目管理师练习题】信息系统治理

IT治理的核心是关注以下哪项内容? a) 人员培训和发展计划 b) IT定位和信息化建设与数字化转型的责权利划分 c) 业务流程的绩效管理 d) IT基础设施的优化利用 答案: b) IT定位和信息化建设与数字化转型的责权利划分 IT治理体系框架的组成部分包括以下哪些? a) IT战略目标、IT治…...

RabbitMQ之顺序消费

什么是顺序消费 例如&#xff1a;业务上产生者发送三条消息&#xff0c; 分别是对同一条数据的增加、修改、删除操作&#xff0c; 如果没有保证顺序消费&#xff0c;执行顺序可能变成删除、修改、增加&#xff0c;这就乱了。 如何保证顺序性 一般我们讨论如何保证消息的顺序性&…...

网络编程(Modbus进阶)

思维导图 Modbus RTU&#xff08;先学一点理论&#xff09; 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议&#xff0c;由 Modicon 公司&#xff08;现施耐德电气&#xff09;于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架&#xff0c;它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用&#xff0c;和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

Android Wi-Fi 连接失败日志分析

1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分&#xff1a; 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析&#xff1a; CTR…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法&#xff0c;用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理&#xff0c;能够自动确定一个阈值&#xff0c;将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

C++ 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

算法:模拟

1.替换所有的问号 1576. 替换所有的问号 - 力扣&#xff08;LeetCode&#xff09; ​遍历字符串​&#xff1a;通过外层循环逐一检查每个字符。​遇到 ? 时处理​&#xff1a; 内层循环遍历小写字母&#xff08;a 到 z&#xff09;。对每个字母检查是否满足&#xff1a; ​与…...

人工智能(大型语言模型 LLMs)对不同学科的影响以及由此产生的新学习方式

今天是关于AI如何在教学中增强学生的学习体验&#xff0c;我把重要信息标红了。人文学科的价值被低估了 ⬇️ 转型与必要性 人工智能正在深刻地改变教育&#xff0c;这并非炒作&#xff0c;而是已经发生的巨大变革。教育机构和教育者不能忽视它&#xff0c;试图简单地禁止学生使…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块&#xff0c;用于对本地知识库系统中的知识库进行增删改查&#xff08;CRUD&#xff09;操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 &#x1f4d8; 一、整体功能概述 该模块…...

深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用

文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么&#xff1f;1.1.2 感知机的工作原理 1.2 感知机的简单应用&#xff1a;基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...