从0开始深度学习(12)——多层感知机的逐步实现
依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过
1 读取数据
import torch
from torchvision import transforms
import torchvision
from torch.utils import data# 读取数据
def load_data_fashion_mnist(batch_size, resize=None): #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=True, transform=trans, download=False)mnist_test = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=False, transform=trans, download=False)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=12),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=12))train_iter, test_iter = load_data_fashion_mnist(256, resize=28)
2 初始化模型参数
以单隐藏层的多层感知机为例,选择使用256个隐藏单元
from torch import nn# 初始化模型参数
num_inputs=784 # 28*28
num_outputs=10
num_hiddens=256 # 我们选择使用256个隐藏单元,注意,一般选择使用2的若干次幂,因为内存的特殊性,可以在计算上更高效w1 = nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)*0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [w1, b1, w2, b2]
3 激活函数、损失函数、建立模型
# 激活函数
def relu(x):a=torch.zeros_like(x) # 保证全零张量和x的形状一致,利于广播计算return torch.max(x,a)# 损失函数
loss = nn.CrossEntropyLoss(reduction='none')#建立模型
def net(x):x=x.reshape((-1,num_inputs))#展开H=relu(x@w1+b1)# @表示矩阵乘法return (H@w2+b2)
4 训练模型
优化器使用SGD
#训练,优化器使用sgd
num_epochs=5
lr=00.1
updater=torch.optim.SGD(params,lr=lr)def train_epoch(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train() # 将模型设置为训练模式metric = Accumulator(3) # 训练损失总和、训练准确度总和、样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()else:l.backward()updater([w, b], lr, batch_size)metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]def train(net, train_iter, test_iter, loss, num_epochs, updater):for epoch in range(num_epochs):train_metrics = train_epoch(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')class Accumulator: #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def compute_accuracy(y_hat, y): # 预测值、真实值if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1) # 找到一个样本中,对应的最大概率的类别cmp = y_hat.type(y.dtype) == y # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmpreturn float(cmp.type(y.dtype).sum())# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter): if isinstance(net, torch.nn.Module):net.eval() # 将模型设置为评估模式metric = Accumulator(2) # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。with torch.no_grad():for X, y in data_iter:metric.add(compute_accuracy(net(X), y), y.numel())return metric[0] / metric[1]train(net, train_iter, test_iter, loss, num_epochs, updater)
5 预测
import matplotlib.pyplot as plt
# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 预测并显示结果
def predict(net, test_iter, n=6):for X, y in test_iter:break # 只取一个批次的数据trues = get_fashion_mnist_labels(y)preds = get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + '\n' + pred for true, pred in zip(trues, preds)]n = min(n, X.shape[0])fig, axs = plt.subplots(1, n, figsize=(12, 3))for i in range(n):axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')axs[i].set_title(titles[i])axs[i].axis('off')plt.show()# 调用预测函数
predict(net, test_iter, n=6)
相关文章:

从0开始深度学习(12)——多层感知机的逐步实现
依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过 1 读取数据 import torch from torchvision import transforms import torchv…...

如何利用OpenCV和yolo实现人脸检测
在之前的blog里面,我们有介绍OpenCV和yolo的区别,本文就人脸检测为例,分别介绍下OpenCV和yolo的实现方式。 OpenCV实现人脸检测 一、安装 OpenCV 首先确保你已经安装了 OpenCV 库。可以通过以下方式安装: 使用包管理工具安装&…...

015集——c# 实现CAD excel交互(CAD—C#二次开发入门)
第一步:添加引用 程序集—>扩展 namespace WindowsFormsApp2 {public partial class Form1 : Form{public Form1(){InitializeComponent();}private void Form1_Load(object sender, EventArgs e){}private void 获取当前excel_Click(object sender, EventArgs e…...

【计网笔记】以太网
经典以太网 总线拓扑 物理层 Manchester编码 数据链路层 MAC子层 MAC帧 DIX格式与IEEE802.3格式 IEEE802.3格式兼容DIX格式 前导码(帧开始定界符SOF) 8字节 前7字节均为0xAA第8字节为0xAB前7字节的Manchester编码将产生稳定方波,用于…...

Java 入门基础篇14 - java面向对象思想以及特性
学习目标: 一、目标 面向对象思想类和对象对象的创建和使用属性和方法封装 开始学习: 二、编程思想 2.1 什么是编程思想 做人有做人的原则,编程也有编程的原则。这些编程的原则,就叫做编程思想。 2.2 面向过程和面向对象 二…...

第15篇:网络架构优化与综合案例分析
目录 引言 15.1 网络性能优化的方法与工具 15.1.1 带宽管理与流量控制 15.1.2 负载均衡 15.1.3 缓存优化 15.2 网络故障的排查与解决 15.2.1 常用的网络故障排查工具 15.2.2 网络故障排查案例 15.3 网络安全架构的综合设计案例 15.3.1 企业网络安全架构的要求 15.3.…...

UI自动化测试实战
补充:Selenium主要用于Web页面的自动化测试,它可以模拟用户的各种操作,如点击、输入、滚动等,来测试网页的功能。而Appium是一个开源的移动端自动化测试工具。 一、自动化测试实战章节 自动化测试流程测试用例编写项目自动化测试…...

东方智者颜廷利:以哲学思想促进世界和谐与无私奉献
【本社讯】在全球化的今天,东方智慧与哲学思想正逐渐成为促进世界和谐与理解的重要力量。近日,祖籍齐鲁大地山东济南的东方智者颜廷利以其深邃的哲学思想和对人类社会的深刻洞察,引起了国际社会的广泛关注。 颜廷利,一位致力于哲学研究与实践的智者,他的思想跨越古今,融合了东…...

基于 springboot vue停车场管理系统 设计与实现
博主介绍:专注于Java(springboot ssm 等开发框架) vue .net php phython node.js uniapp 微信小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不…...

如何验证ssl私钥和证书是否匹配?
从证书(CRT)文件提取公钥 openssl x509 -in server.crt -pubkey -noout | openssl sha256从证书签名请求(CSR)文件提取公钥 openssl req -in server.csr -pubkey -noout | openssl sha256从私钥(KEY)文件…...

MongoDB的基本操作
🌷数据库准备 🎈Mongoshell 1.在指定目录下创建mongodb文件夹、其子文件log和data以及mongodb.log cd /home/ubuntu mkdir -p mongodb/data mkdir -p mongodb/log touch mongodb/log/mongodb.log 执行mongodb命令启动mongdb服务 mongod --dbpath /h…...

spring mvc后端实现过程
文章目录 一、Spring mvc1、controller1.1、LoginController011.2、LoginController 2、service2.1、LoginService2.1、LoginInimplements 3、dao3.1、LoginMapper3.1、LoginMapper.xml 4、实体类 一、Spring mvc 1、controller 控制器层、处理用户的请求和响应, …...

102005
import os os.environ["CUDA_VISIBLE_DEVICES"] "0" # 设定使用的 GPUimport tensorflow as tf from dataset import generate_data import numpy as np from model import enhancednet# 检查 TensorFlow 是否可以识别 GPU gpus tf.config.list_physica…...

Cisco ACI环境给Leaf配置OOB带外管理IP方法
可以通过GUI 或CLI进行配置 通过CLI更简单,和配置传统交换机差不多, ACI中共有3大组件 APIC 控制器 SPINE 核心 LEAF 接入 下面我们将3种角色的带外IP配置方法都列出来 1 APIC配置带外IP This example shows how to configure out-of-band managemen…...

免费送源码:Java+B/S+MySQL springboot电影推荐系统 计算机毕业设计原创定制
摘 要 随着互联网与移动互联网迅速普及,网络上的电影娱乐信息数量相当庞大,人们对获取感兴趣的电影娱乐信息的需求越来越大,个性化的电影推荐系统成为一个热门。然而电影信息的表示相当复杂,己有的相似度计算方法与推荐算法都各有优势&#…...

数据清洗(脚本)
使用脚本清洗数据时,可以根据具体的数据问题选择编程语言,如Shell、Python、SQL等。这里我以 Python(Pandas库) 和 SQL 为例,演示如何通过脚本进行数据清洗。 1. 使用 Python(Pandas库) 进行数…...

jmeter中发送post请求遇到的问题
用jmeter发送post请求,把请求参数放在Body Data处,参数都写得正确,但没想到结果每次都报错,直接响应结果乱七八糟,改成用Parameters,反而不乱报错了。 上图 请求里如下 另外一些请求也是这样 这个响应结果也是错误的…...

Java中使用protobuf
一、简介 Protocal Buffers(简称protobuf)是谷歌的一项技术,用于结构化的数据序列化、反序列化。 Protocol Buffers 是一种语言无关、平台无关、可扩展的序列化结构数据的方法,它可用于(数据)通信协议、数据存储等。 Protocol B…...

2020款Macbook Pro A2251无法充电无法开机定位及修复
问题背景 up主有一台2020年的Macbook Pro,带Touch Bar,16G512G,四核I5,型号A2251 应该是一周没充电了,之前还用的好好的,后来有一天出差想带上 打开没电,手头上有个小米的66W快充头,…...

Spring Cloud --- 引入Gateway网关
引入Gateway网关 介绍 Spring Cloud Gateway 组件的核心是一系列的过滤器,通过这些过滤器可以将客户端发送的请求转发(路由)到对应的微服务。 Spring Cloud Gateway 是加在整个微服务最前沿的防火墙和代理器,隐藏微服务结点 IP 端口信息,从…...

ESP32-C3实现定时器的启停(Arduino IDE)
1概述 ESP32-C3微控制器有多个定时器,这些定时器可用于各种用途,包括计时、生成PWM信号、测量输入信号的频率等。以下是ESP32-C3上可用的定时器资源: 两个硬件定时器: 定时器0:这是一个通用定时器,通常用于…...

centos升级g++使其支持c++17
centos升级g使其支持c17 升级g的原因现象原因 升级g方法更新镜像源yum升级g版本 总结 升级g的原因 现象 编译最新版本的jsoncpp报一下错误 jsontest.h:87:37: error: ‘hexfloat’ is not a member of ‘std’oss << std::setprecision(16) << std::hexfloat &l…...

Pytest日志收集器配置
前言 在pytest框架中,日志记录(logging)是一个强大的功能,它允许我们在测试期间记录信息、警告、错误等,从而帮助调试和监控测试进度。 pytest与Python标准库中的logging模块完美集成,因此你可以很容易地在…...

Morris算法(大数据作业)
我只能说,概率证明真的好难啊!(;′⌒) 这也证明我的概率论真的学的很差劲,有时间一定要补补/(ㄒoㄒ)/~~ 算法不难证明难! 当一个数足够大时,能不能用更少的空间来近似表示这个整数n,于是&…...

TCP/IP协议 【三次握手】过程简要描述
当建立TCP连接时,三次握手的作用简要描述如下: 第一次握手(客户端向服务器发送SYN包):客户端发送SYN包给服务器,确认服务器是否在线并等待响应。 第二次握手(服务器向客户端发送SYNACK包&…...

docker 数据管理,数据持久化详解 二 数据卷容器
数据卷和数据卷容器核心区别 持久性对比 数据卷:当您直接在启动容器时指定了一个数据卷(例如,使用docker run -v /data),这个数据卷会自动创建,并且其内容会在容器停止或删除后继续存在。您可以随时通过Do…...

Logrotate:Linux系统日志轮转和管理的实用指南
Logrotate是Linux系统中用于自动化管理日志文件的强大工具,它能够高效、安全地轮转、压缩和清理日志文件,从而有效控制日志文件大小,节省磁盘空间,并显著提升系统可维护性和安全性。本文档将提供Logrotate的实用指南,涵…...

八股面试3(自用)
基本数据类型和引用数据类型区别 java中数据类型分为基本数据类型和引用数据类型 8大基本数据类型 1.整数:int,long,short,byte 2.浮点类型:float,double 3.字符类型:char 4.布尔类型&…...

【微服务】springboot3 集成 Flink CDC 1.17 实现mysql数据同步
目录 一、前言 二、常用的数据同步解决方案 2.1 为什么需要数据同步 2.2 常用的数据同步方案 2.2.1 Debezium 2.2.2 DataX 2.2.3 Canal 2.2.4 Sqoop 2.2.5 Kettle 2.2.6 Flink CDC 三、Flink CDC介绍 3.1 Flink CDC 概述 3.1.1 Flink CDC 工作原理 3.2 Flink CDC…...

【Android】浅析OkHttp(1)
【Android】浅析OkHttp(1) OkHttp 是一个高效、轻量级的 HTTP 客户端库,主要用于 Android 和 Java 应用开发。它不仅支持同步和异步的 HTTP 请求,还支持许多高级功能,如连接池、透明的 GZIP 压缩、响应缓存、WebSocke…...