山东大学软件学院ai导论实验之生成对抗网络
目录
实验目的
实验代码
实验内容
实验结果
实验目的
基于Pytorch搭建一个生成对抗网络,使用MNIST数据集。
实验代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os# 设置环境变量
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"# 创建保存生成图像的文件夹
output_path = r"xxxxxxxxxxxxxxxxxx"
os.makedirs(output_path, exist_ok=True)# 生成器网络
class Generator(nn.Module):def __init__(self, latent_dim):super(Generator, self).__init__()self.network = nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, z):img = self.network(z)return img.view(img.size(0), 1, 28, 28)# 判别器网络
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.network = nn.Sequential(nn.Linear(784, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):return self.network(img.view(img.size(0), -1))def generate_and_save_images(generator, test_input, epoch, img_path):with torch.no_grad():generated_images = generator(test_input).cpu().numpy()fig, axes = plt.subplots(4, 4, figsize=(4, 4))for i, ax in enumerate(axes.flat):# 将图像从形状 (1, 28, 28) 转换为 (28, 28),去除通道维度ax.imshow(np.squeeze(generated_images[i]), cmap='gray')ax.axis('off')img_filename = os.path.join(img_path, f"generated_epoch_{epoch}.png")plt.tight_layout()plt.savefig(img_filename)plt.close()# 设置设备(使用GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 超参数
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 2000# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = datasets.MNIST(root='./MNIST_data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)# 测试数据:随机噪声作为输入
test_data = torch.randn(batch_size, latent_dim).to(device)# 初始化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)# 记录损失
D_losses = []
G_losses = []# 训练过程
for epoch in range(epochs):for i, (imgs, _) in enumerate(train_loader):real_imgs = imgs.to(device)batch_size = real_imgs.size(0)# 判别器训练z = torch.randn(batch_size, latent_dim).to(device)fake_imgs = generator(z)real_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# 计算损失real_loss = adversarial_loss(discriminator(real_imgs), real_labels)fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)d_loss = (real_loss + fake_loss) / 2optimizer_D.zero_grad()d_loss.backward()optimizer_D.step()# 生成器训练z = torch.randn(batch_size, latent_dim).to(device)fake_imgs = generator(z)g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()# 记录损失D_losses.append(d_loss.item())G_losses.append(g_loss.item())# 打印每2000个步骤的迭代信息if (epoch * len(train_loader) + i) % 2000 == 0:print(f"Iter: {epoch * len(train_loader) + i}")print(f"D_loss: {d_loss.item():.4f}")print(f"G_loss: {g_loss.item():.4f}")# 每个epoch保存生成的图像generate_and_save_images(generator, test_data, epoch, output_path)# 保存生成器和判别器的模型torch.save(generator.state_dict(), "Generator_mnist.pth")torch.save(discriminator.state_dict(), "Discriminator_mnist.pth")# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.plot(D_losses, label='Discriminator Loss')
plt.plot(G_losses, label='Generator Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.savefig('loss_curve.png') # 保存图像
plt.show() # 显示图像
实验内容
1. 数据集加载
与前几次实验一样,本实验仍然使用MNIST数据集作为输入数据集通过torchvision库进行加载并标准化处理,使得图像像素值在[-1, 1]范围内,以适应生成对抗网络的训练要求。

2. 生成器与判别器网络
生成器:生成器网络的任务是生成伪造的图像,以欺骗判别器。输入是一个随机噪声向量(latent vector),输出是一个28x28像素的图像。生成器使用多个全连接层,每个层后面都跟着一个LeakyReLU激活函数,最终输出通过Tanh激活函数确保生成的图像像素值在[-1, 1]范围内。

判别器:判别器网络的任务是区分输入的图像是“真实的”还是“伪造的”。它将图像输入后,通过多个全连接层,最后输出一个介于0和1之间的值,表示图像的真实性。

3. 训练过程
判别器训练:判别器的目标是最大化其准确性,即正确分类真实和伪造的图像。在每次训练中,先计算真实图像的损失,然后计算生成图像的损失,最后将两个损失加权平均得到判别器的总损失。

生成器训练:生成器的目标是最小化判别器对其生成图像的判断错误率。即通过调整其权重,使得生成的图像越来越像真实图像,以此欺骗判别器。生成器的损失函数是判别器对生成图像的输出,标签为“真实”(即1)。
![]()
模型优化:使用Adam优化器分别优化生成器和判别器的参数。学习率为0.0001。

- 改变隐藏层数
生成器的结构由原来的4个隐藏层缩减为2个隐藏层:


5.生成图像并保存
在每个epoch结束时,使用生成器生成一些图像,并将图像保存为PNG格式文件。每个epoch的图像被保存到指定的文件夹中,以便可视化生成图像的变化。

6. 绘制损失曲线
训练过程中记录并绘制判别器和生成器的损失曲线,以便观察模型的训练进展。

实验结果
迭代得到的训练结果为:


改变隐藏层数得到的部分结果为:

刚开始生成的初始图像为:

运行一段时间后,得到的图像为:



可以明显的看到,随着迭代不断增加,数字越来越清晰,数字识别成功
损失曲线为:
初始:

慢慢的趋于平稳:

相关文章:
山东大学软件学院ai导论实验之生成对抗网络
目录 实验目的 实验代码 实验内容 实验结果 实验目的 基于Pytorch搭建一个生成对抗网络,使用MNIST数据集。 实验代码 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data…...
C++ Qt常见面试题(2):QT中的文件流(QTextStream)和数据流(QDataStream)的区别
在 Qt 中,QTextStream 和 QDataStream 是两种常用的流类,用于通过文件或其他 I/O 设备(如网络、内存)读写数据。虽然它们都可以用来操作数据,但它们的设计目标和使用场景不同。以下是它们的主要区别和适用场景的详细说明: 1. QTextStream:文本流 QTextStream 是一种专门…...
入门网络安全工程师要学习哪些内容【2025年寒假最新学习计划】
🤟 基于入门网络安全/黑客打造的:👉黑客&网络安全入门&进阶学习资源包 大家都知道网络安全行业很火,这个行业因为国家政策趋势正在大力发展,大有可为!但很多人对网络安全工程师还是不了解,不知道网…...
【论文解读】《C-Pack: Packed Resources For General Chinese Embeddings》
论文链接:https://arxiv.org/pdf/2309.07597 本论文旨在构建一套通用中文文本嵌入的完整资源包——C-Pack,解决当前中文文本嵌入研究中数据、模型、训练策略与评测基准缺失的问题。论文主要贡献体现在以下几个方面: 大规模训练数据…...
Cramér-Rao界:参数估计精度的“理论底线”
Cramr-Rao界:参数估计精度的“理论底线” 在统计学中,当我们用数据估计一个模型的参数时,总希望估计结果尽可能精确。但精度有没有一个理论上的“底线”呢?答案是有的,这就是Cramr-Rao界(Cramr-Rao Lower …...
ClickHouse 的分区、分桶和分片详解
在大数据场景下,数据的存储和查询效率至关重要。ClickHouse 作为一款高性能的列式存储数据库,提供了多种数据组织方式来优化存储和查询,其中最常见的就是 分区(Partition)、分桶(Sampling)、分片…...
【操作系统、数学】什么是排队论?如何理解排队论?排队论有什么用处?Queueing Theory?什么是 Little’s Law?
排队论(Queueing Theory)是研究系统中排队现象的数学理论,旨在分析资源分配、服务效率及等待时间等问题。它广泛应用于计算机科学、通信网络、交通规划、工业工程等领域。 【下文会通过搜集的资料,从各方面了解排队论,…...
2209. 用地毯覆盖后的最少白色砖块
2209. 用地毯覆盖后的最少白色砖块 题目链接:2209. 用地毯覆盖后的最少白色砖块 代码如下: class Solution { public:int minimumWhiteTiles(string floor, int numCarpets, int carpetLen) {vector<vector<int>>memo (numCarpets 1, vec…...
DeepSeek赋能大模型内容安全,网易易盾AIGC内容风控解决方案三大升级
在近两年由AI引发的生产力革命的背后,一场关乎数字世界秩序的攻防战正在上演:AI生成的深度伪造视频导致企业品牌声誉损失日均超千万,批量生成的侵权内容使版权纠纷量与日俱增,黑灰产利用AI技术持续发起欺诈攻击。 与此同时&#…...
(0)阿里云大模型ACP-考试回忆
这两天通过了阿里云大模型ACP考试,由于之前在网上没有找到真题,导致第一次考试没有过,后面又重新学习了一遍文档才顺利通过考试,这两次考试内容感觉考试题目90%内容是覆盖的,后面准备分享一下每一章的考题,…...
0.【深度学习YOLOV11项目实战-项目安装教程】(图文教程,超级详细)
目录 前言一、安装Pycharm(安装过Pycharm的跳过这一步)1.1 点击下述链接直接跳转到教程页面进行安装 二、安装Anaconda(安装过Anaconda的跳过这一步)2.1 点击下述链接直接跳转到教程页面进行安装 三、后续安装教程(有N…...
Docker 部署 Jenkins持续集成(CI)工具
[TOC](Docker 部署 Jenkins持续集成(CI)工具) 前言 Jenkins 是一个流行的开源自动化工具,广泛应用于持续集成(CI)和持续交付(CD)的环境中。通过 Docker 部署 Jenkins,可以简化安装和配置过程,并…...
布署elfk-准备工作
建议申请5台机器部署elfk: filebeat(每台app)--> logstash(2台keepalived)--> elasticsearch(3台)--> kibana(部署es上)采集输出 处理转发 分布式存储 展示 ELK中文社区: 搜索客,搜索人自己的社区 官方…...
uniapp 小程序如何实现大模型流式交互?前端SSE技术完整实现解析
文章目录 一、背景概述二、核心流程图解三、代码模块详解1. UTF-8解码器(处理二进制流)2. 请求控制器(核心通信模块)3. 流式请求处理器(分块接收)4. 数据解析器(处理SSE格式)5. 回调…...
微软推出Office免费版,限制诸多,只能编辑不能保存到本地
易采游戏网2月25日独家消息:微软宣布推出一款免费的Office版本,允许用户进行基础文档编辑操作,但限制颇多,其中最引人关注的是用户无法将文件保存到本地。这一举措引发了广泛讨论,业界人士对其背后的商业策略和用户体验…...
《ArkTS鸿蒙应用开发入门到实战》—新手小白学习鸿蒙的推荐工具书!
《ArkTS鸿蒙应用开发入门到实战》—新手小白学习鸿蒙的推荐工具书! 在科技日新月异的今天,鸿蒙操作系统(HarmonyOS)作为华为推出的全新操作系统,正迅速进入越来越多的智能设备,成为物联网和智能硬件领域的…...
销售成交九步思维魔方
销售成交九步思维魔方 点 一、确定需求 原则1:问题是需求的前身原则2:基于问题才做决定原则3:人只解决大的问题 二、塑造价值 USP 利益 快乐 痛苦 价值 线 三、销售准备 精神上的准备 体能上的准备 产品知识准备 彻底了解顾客背景 …...
橄榄球、棒球项目排名·棒球1号位
美国四大体育联盟按照规模、影响力等因素综合排名,通常认为是: 1. NFL(国家橄榄球联盟):成立于1920年,是北美四大职业体育运动联盟之首,也是世界上最大的职业美式橄榄球联盟。由32支球队组成&am…...
DeepSeek 提示词:高效的提示词设计
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…...
AI硬件加速的核心:深入探讨AI加速芯片模组的设计与应用
随着人工智能应用的普及,传统的计算架构已无法满足大规模深度学习模型训练和推理的需求。为了加速计算过程并提高能效,AI加速芯片应运而生。本文将介绍AI加速芯片模组的关键技术、发展趋势以及在各类应用中的重要性。 AI加速芯片模组的定义与构成 AI加速…...
LangChain:Models、Prompts、Indexes、Memory、Chains、Agents。MaxKB
LangChain:Models、Prompts、Indexes、Memory、Chains、Agents 在LangChain框架中,Models、Prompts、Indexes、Memory、Chains、Agents是六大核心抽象概念,它们各自承担独特功能,相互协作以助力开发者基于大语言模型构建高效智能应用。 Models(模型):指代各类大语言模型…...
html中的css
css (cascading style sheets,串联样式表,也叫层叠样式表) css规范一般约定: 1.存放CSS样式文件的目录一般命名为style或css。 2.在项目初期,会把不同类别的样式放于不同的CSS文件,是为了CSS编…...
JAVA面试常见题_基础部分_Dubbo面试题(上)
Dubbo 支持哪些协议,每种协议的应用场景,优缺点? • dubbo: 单一长连接和 NIO 异步通讯,适合大并发小数据量的服务调用,以及消费者远大于提供者。传输协议 TCP,异步,Hessian 序列化…...
Binder通信协议
目录 一,整体架构 二,Binder通信协议 一,整体架构 二,Binder通信协议...
解决应用程序 0xc00000142 错误:完整修复指南
💥 0xc00000142 错误出现的场景 你是不是遇到这样的情况: 🔹 点击某个软件,突然弹出“应用程序无法正确启动(0xc00000142)” ? 🔹 明明安装了所有必要组件,软件却始终打不开? &…...
游戏引擎学习第125天
仓库:https://gitee.com/mrxiao_com/2d_game_3 回顾并为今天的内容做准备。 昨天,当我们离开时,工作队列已经完成了基本的功能。这个队列虽然简单,但它能够执行任务,并且我们已经为各种操作编写了测试。字符串也能够正常推送到队…...
[免单统计]
免单统计 真题目录: 点击去查看 E 卷 100分题型 题目描述 华为商城举办了一个促销活动,如果某顾客是某一秒内最早时刻下单的顾客(可能是多个人),则可以获取免单。 请你编程计算有多少顾客可以获取免单。 输入描述 输入为 n 行数据,每一行表示一位顾客的下单时间 以(…...
DeepSeek R1满血+火山引擎详细教程
DeepSeek R1满血火山引擎详细教程 一、安装Cherry Studio。 Cherry Studio AI 是一款强大的多模型 AI 助手,支持 iOS、macOS 和 Windows 平台。可以快速切换多个先进的 LLM 模型,提升工作学习效率。下载地址 https://cherry-ai.com/ 认准官网,无强制注册。 这…...
前端依赖nrm镜像管理工具
npm 默认镜像 :https://registry.npmjs.org/ 1、安装 nrm npm install nrm --global2、查看镜像源列表 nrm ls3、测试当前环境下,哪个镜像源速度最快。 nrm test4、 切换镜像源 npm config get registry # 查看当前镜像源 nrm use taobao # 等价于 npm…...
【前端】Axios AJAX Fetch
不定期更新,建议关注收藏点赞。 目录 AxiosAJAXCORS 允许跨域请求 Fetch Axios axios 是一个基于 Promise 的 JavaScript HTTP 客户端,用于浏览器和 Node.js 中发送 HTTP 请求。它提供了一个简单的 API 来发起请求,并处理请求的结果。axios …...
