当前位置: 首页 > news >正文

第G9周:ACGAN理论与实战

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

上一周已经给出代码,需要可以跳转上一周的任务
第G8周:ACGAN任务

import argparse
import os
import numpy as npimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch# 创建用于存储生成图像的目录
os.makedirs("images", exist_ok=True)# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="训练的总轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每个批次的大小")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的一阶动量衰减")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的二阶动量衰减")
parser.add_argument("--n_cpu", type=int, default=4, help="用于批次生成的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--n_classes", type=int, default=10, help="数据集的类别数")
parser.add_argument("--img_size", type=int, default=32, help="每个图像的尺寸")
parser.add_argument("--channels", type=int, default=1, help="图像通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="图像采样间隔")
opt = parser.parse_args()
print(opt)# 检查是否支持GPU加速
cuda = True if torch.cuda.is_available() else False# 初始化神经网络权重的函数
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)# 生成器网络类
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 为类别标签创建嵌入层self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)# 计算上采样前的初始大小self.init_size = opt.img_size // 4  # Initial size before upsampling# 第一层线性层self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))# 卷积层块self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):# 将标签嵌入到噪声中gen_input = torch.mul(self.label_emb(labels), noise)# 通过第一层线性层out = self.l1(gen_input)# 重新整形为合适的形状out = out.view(out.shape[0], 128, self.init_size, self.init_size)# 通过卷积层块生成图像img = self.conv_blocks(out)return img# 判别器网络类
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器块的函数def discriminator_block(in_filters, out_filters, bn=True):"""返回每个判别器块的层"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return block# 判别器的卷积层块self.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样后图像的高度和宽度ds_size = opt.img_size // 2 ** 4# 输出层self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# 损失函数
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor# 保存生成图像的函数
def sample_image(n_row, batches_done):"""保存从0到n_classes的生成数字的图像网格"""# 采样噪声z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# 为n行生成标签从0到n_classeslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)# ----------
# 训练
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 真实数据的标签valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)# 生成数据的标签fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)# 配置输入real_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------# 训练生成器# -----------------optimizer_G.zero_grad()# 采样噪声和标签作为生成器的输入z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))# 生成一批图像gen_imgs = generator(z, gen_labels)# 损失度量生成器的欺骗判别器的能力validity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 真实图像的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# 生成图像的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 判别器的总损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算判别器的准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

相关文章:

第G9周:ACGAN理论与实战

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 上一周已经给出代码,需要可以跳转上一周的任务 第G8周:ACGAN任…...

Linux网络部分——DNS域名解析服务

目录 1. 域名结构 2. 系统根据域名查找IP地址的过程 3.DNS域名解析方式 4.DNS域名解析的工作原理【☆】 5.域名解析查询方式 6.搭建主从DNS域名服务器 ①初始化操作主服务器和从服务器,安装BIND软件 ②修改主服务器的主配置文件、区域配置文件、区域数…...

预处理详解

乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 点击主页:optimistic_chen和专栏:c语言, 创作不易,大佬们点赞鼓…...

Python的创建和使用自定义模块

Python 的模块是组织代码的基本单元,它可以包含变量、函数、类等,并且可以被其他 Python 程序引用和重用。除了使用 Python 提供的标准库和第三方库外,开发者还可以创建自定义模块,用于组织和管理自己的代码。本文将详细介绍如何创…...

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具)

Python根据预设txt生成“你画我猜”题目PPT(素拓活动小工具) 场景来源 去年单位内部的一次素拓活动,分工负责策划设置其中的“你画我猜”环节,网络上搜集到题目文字后,想着如何快速做成对应一页一页的PPT。第一时间想…...

小程序地理位置接口权限直接抄作业

小程序地理位置接口有什么功能? 随着小程序生态的发展,越来越多的小程序开发者会通过官方提供的自带接口来给用户提供便捷的服务。但是当涉及到地理位置接口时,却经常遇到申请驳回的问题,反复修改也无法通过,给的理由也…...

【Osek网络管理测试】[TG3_TC6]等待总线睡眠状态_2

🙋‍♂️ 【Osek网络管理测试】系列💁‍♂️点击跳转 文章目录 1.环境搭建2.测试目的3.测试步骤4.预期结果5.测试结果 1.环境搭建 硬件:VN1630 软件:CANoe 2.测试目的 验证DUT在满足进入等待睡眠状态的条件时是否进入该状态 …...

BEV下统一的多传感器融合框架 - FUTR3D

BEV下统一的多传感器融合框架 - FUTR3D 引言 在自动驾驶汽车或者移动机器人上,通常会配备许多种传感器,比如:光学相机、激光雷达、毫米波雷达等。由于不同传感器的数据形式不同,如RGB图像,点云等,不同模态…...

c#和python的flask接口的交互

一、灰度图像的传输 c#端的传输 //读入文件夹中的图像 Mat img2 new Mat(file, ImreadModes.AnyColor); //将图像的数据转换成和相机相同的buffer数据 byte[] image_buffer new byte[img2.Width * img2.Height]; int cn img2.Channels(); //通道数 if (cn 1){//将图像的数…...

Python测试框架Pytest的参数化详解

上篇博文介绍过,Pytest是目前比较成熟功能齐全的测试框架,使用率肯定也不断攀升。 在实际工作中,许多测试用例都是类似的重复,一个个写最后代码会显得很冗余。这里,我们来了解一下pytest.mark.parametrize装饰器&…...

KernelSU 如何不通过模块,直接修改系统分区

刚刚看了术哥发的视频,发现kernelSU通过挂载OverlayFS实现无需模块,即可直接修改系统分区,很是方便,并且安全性也很高,于是便有了这篇文章。 下面的教程与原视频存在差异,建议观看原视频后再结合本文章进行操作。 在未进行修改前,我们打开/system/文件夹,并在里面创建…...

红日靶场ATTCK 1通关攻略

环境 拓扑图 VM1 web服务器 win7(192.168.22.129,10.10.10.140) VM2 win2003(10.10.10.135) VM3 DC win2008(10.10.10.138) 环境搭建 win7: 设置内网两张网卡,开启…...

CellMarker | 人骨骼肌组织细胞Marker大全!~(强烈建议火速收藏!)

1写在前面 分享一下最近看到的2篇paper关于骨骼肌组织的细胞Marker&#xff0c;绝对的Atlas级好东西。&#x1f44d; 希望做单细胞的小伙伴觉得有用哦。&#x1f60f; 2常用marker&#xff08;一&#xff09; general_mrkrs <- c( MYH7, TNNT1, TNNT3, MYH1, MYH2, "C…...

游戏名台词大赏

文章目录 原神&#xff08;圈内&#xff09; 崩坏&#xff1a;星穹铁道&#xff08;圈内&#xff09; 崩坏3&#xff08;圈内&#xff09; 原神 只要不失去你的崇高&#xff0c;整个世界都会为你敞开。 总会有地上的生灵&#xff0c;敢于直面雷霆的威光。 谁也没有见过风&…...

OpenCV如何在图像中寻找轮廓(60)

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV如何模板匹配(59) 下一篇 :OpenCV检测凸包(61) 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 cv::findContours使用 OpenCV 函数 cv::d rawContours …...

java 泛型题目讲解

泛型的知识点 泛型仅存在于编译时期&#xff0c;编译期间JAVA将会使用Object类型代替泛型类型&#xff0c;在运行时期不存在泛型&#xff1b;且所有泛型实例共享一个泛型类 public class Main{public static void main(String[] args){ArrayList<String> list1new Arra…...

pptx 文件版面分析-- python-pptx(python 文档解析提取)

安装 pip install python-pptx -i https://pypi.tuna.tsinghua.edu.cn/simple --ignore-installedpptx 解析代码实现 from pptx import Presentation file_name "rag_pptx/test1.pptx" # 打开.pptx文件 ppt Presentation(file_name) for slide in ppt.slides:#pr…...

http的basic 认证方式

写在前面 本文看下http的basic auth认证方式。 1&#xff1a;什么是basic auth认证 basic auth是一种http协议规范中的一种认证方式&#xff0c;即一种证明你就是你的方式。更进一步的它是一种规范&#xff0c;这种规范是这样子&#xff0c;如果是服务端使用了basic auth认证…...

【信息系统项目管理师练习题】信息系统治理

IT治理的核心是关注以下哪项内容? a) 人员培训和发展计划 b) IT定位和信息化建设与数字化转型的责权利划分 c) 业务流程的绩效管理 d) IT基础设施的优化利用 答案: b) IT定位和信息化建设与数字化转型的责权利划分 IT治理体系框架的组成部分包括以下哪些? a) IT战略目标、IT治…...

RabbitMQ之顺序消费

什么是顺序消费 例如&#xff1a;业务上产生者发送三条消息&#xff0c; 分别是对同一条数据的增加、修改、删除操作&#xff0c; 如果没有保证顺序消费&#xff0c;执行顺序可能变成删除、修改、增加&#xff0c;这就乱了。 如何保证顺序性 一般我们讨论如何保证消息的顺序性&…...

国密SM9在微服务网关中TPS骤降42%的真实案例,从ASN.1编码冗余到ZKP预计算的7步性能修复清单

第一章&#xff1a;SM9国密算法在微服务网关中的性能瓶颈全景图 SM9作为我国自主设计的基于身份的密码算法&#xff08;IBC&#xff09;&#xff0c;其双线性对运算、私钥生成与密文解封等核心操作天然引入显著计算开销。当部署于高并发、低延迟要求的微服务网关&#xff08;如…...

从“变速齿轮”到“创新引擎”:解码阿里“大中台、小前台”战略的演进与实战

1. 中台战略的起源与本质 第一次听说"大中台、小前台"这个概念时&#xff0c;我正坐在杭州一家咖啡馆里和几位阿里P8的技术专家聊天。他们用了一个特别形象的比喻&#xff1a;"现在的互联网公司就像一辆老式自行车&#xff0c;前台是拼命蹬车的双腿&#xff0c;…...

FlowState Lab少样本学习效果:仅用10条数据生成特定波动模式

FlowState Lab少样本学习效果&#xff1a;仅用10条数据生成特定波动模式 1. 引言&#xff1a;当数据稀缺遇上智能生成 想象一下这样的场景&#xff1a;你手里只有10条设备振动波形数据&#xff0c;却需要分析上千种可能的故障模式。传统方法可能需要收集数月甚至数年的运行数…...

包装器简介

可调用对象&#xff1a;可以使用&#xff08;&#xff09;运算符进行调用的对象&#xff0c;本质是能像函数一样使用的东西常见课调用对象&#xff1a;函数指针&#xff0c;仿函数&#xff0c;lambda表达式我们能否使用统一的方式对其封装&#xff0c;进行调用&#xff0c;这时…...

为什么你的Ping总是丢包?这7个隐藏原因90%的人都忽略了(含Wireshark分析技巧)

为什么你的Ping总是丢包&#xff1f;这7个隐藏原因90%的人都忽略了&#xff08;含Wireshark分析技巧&#xff09; 在网络运维的日常工作中&#xff0c;Ping命令就像网络工程师的听诊器&#xff0c;简单却至关重要。但当你发现Ping测试频繁丢包时&#xff0c;问题往往不像表面看…...

OpenClaw多通道管理:GLM-4.7-Flash同时对接飞书与钉钉的配置技巧

OpenClaw多通道管理&#xff1a;GLM-4.7-Flash同时对接飞书与钉钉的配置技巧 1. 为什么需要多通道管理&#xff1f; 上周我接到一个技术咨询需求&#xff1a;一个小型内容团队需要同时在飞书和钉钉两个平台上接收AI助手服务。他们的编辑用飞书&#xff0c;运营用钉钉&#xf…...

OpenClaw负载测试:GLM-4.7-Flash并发处理能力评估

OpenClaw负载测试&#xff1a;GLM-4.7-Flash并发处理能力评估 1. 测试背景与目标 上周在尝试用OpenClaw自动化处理一批市场调研报告时&#xff0c;遇到了一个典型问题&#xff1a;当我同时提交20份PDF文件让AI助手提取关键数据时&#xff0c;系统开始出现响应延迟和部分任务超…...

ESP32嵌入式系统设计与实现指南

1. 项目概述1.1 系统架构本项目基于ESP32主控芯片设计&#xff0c;采用模块化架构实现多功能嵌入式系统。系统包含以下核心模块&#xff1a;主控单元&#xff1a;ESP32-WROOM-32D模组电源管理&#xff1a;TPS63020升降压转换器传感器接口&#xff1a;I2C/SPI多协议兼容设计人机…...

OpenClaw+GLM-4.7-Flash:个人博客内容自动生成与发布

OpenClawGLM-4.7-Flash&#xff1a;个人博客内容自动生成与发布 1. 为什么选择这个技术组合 去年夏天&#xff0c;我发现自己陷入了写作瓶颈——每周要产出3篇技术博客&#xff0c;但80%的时间都消耗在资料收集和格式调整上。直到发现OpenClawGLM-4.7-Flash这个组合&#xff…...

AI性能测试:TPS之外还要关注什么?

在AI驱动的时代&#xff0c;性能测试已成为软件测试从业者的核心技能。传统软件测试中&#xff0c;TPS&#xff08;Transactions Per Second&#xff0c;每秒事务处理量&#xff09;常被视为黄金指标&#xff0c;用于衡量系统的吞吐能力。然而&#xff0c;AI系统因其独特的计算…...