Python打卡训练营学习记录Day43
作业:
kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化
进阶:并拆分成多个文件
从谷歌图片中拍摄的 10 种不同类别的动物图片
数据预处理
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_splitdef load_data(data_dir, batch_size):# 数据预处理data_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 加载数据集image_dataset = datasets.ImageFolder(data_dir, data_transform)# 划分训练集和验证集train_size = int(0.8 * len(image_dataset))val_size = len(image_dataset) - train_sizetrain_dataset, val_dataset = random_split(image_dataset, [train_size, val_size])train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)dataloaders = {'train': train_dataloader, 'val': val_dataloader}dataset_sizes = {'train': train_size, 'val': val_size}class_names = image_dataset.classesreturn dataloaders, dataset_sizes, class_names
构建并训练 CNN 模型
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self, num_classes):super(SimpleCNN, self).__init__()# 定义特征提取层self.features = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 32, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))# 定义分类层self.classifier = nn.Sequential(nn.Linear(64 * 28 * 28, 512),nn.ReLU(inplace=True),nn.Linear(512, num_classes))def forward(self, x):# 前向传播,先通过特征提取层,再通过分类层x = self.features(x)x = x.view(-1, 64 * 28 * 28)x = self.classifier(x)return x
模型训练模块
import torch
import torch.nn as nn
import torch.optim as optimdef train_model(model, dataloaders, dataset_sizes, criterion, optimizer, num_epochs=25):# 判断是否有可用的 GPU,若有则使用 GPU 进行训练device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model.to(device)for epoch in range(num_epochs):print(f'第 {epoch} 个 epoch,共 {num_epochs - 1} 个 epochs')print('-' * 10)# 每个 epoch 都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train() # 训练模式else:model.eval() # 评估模式running_loss = 0.0running_corrects = 0# 迭代数据for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向传播# 只有在训练时才跟踪历史with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只有在训练阶段才进行反向传播和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} 阶段:损失值: {epoch_loss:.4f} 准确率: {epoch_acc:.4f}')return model
Grad-CAM可视化模块
import torch
import torch.nn.functional as F
import numpy as np
import cv2class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 反向传播钩子函数,用于捕获梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0]# 前向传播钩子函数,用于捕获激活值def forward_hook(module, input, output):self.activations = outputtarget_layer.register_forward_hook(forward_hook)target_layer.register_backward_hook(backward_hook)def forward(self, input_tensor):# 将模型设置为评估模式并进行前向传播self.model.eval()output = self.model(input_tensor)return outputdef generate_cam(self, input_tensor, target_class=None):# 进行前向传播output = self.forward(input_tensor)# 如果未指定目标类别,则选择输出概率最大的类别if target_class is None:target_class = torch.argmax(output, dim=1).item()one_hot = torch.zeros_like(output)one_hot[:, target_class] = 1one_hot.requires_grad_(True)# 清零模型参数的梯度self.model.zero_grad()# 计算损失并进行反向传播(one_hot * output).sum().backward(retain_graph=True)gradients = self.gradients[0]activations = self.activations[0]# 对梯度进行全局平均池化pooled_gradients = torch.mean(gradients, dim=[1, 2])for i in range(activations.shape[0]):activations[i, :, :] *= pooled_gradients[i]# 对激活值求和生成 CAM 图cam = torch.sum(activations, dim=0).detach().cpu().numpy()# 取 CAM 图的正值部分cam = np.maximum(cam, 0)# 调整 CAM 图的大小以匹配输入图像cam = cv2.resize(cam, (input_tensor.shape[3], input_tensor.shape[2]))# 归一化 CAM 图cam = cam - np.min(cam)cam = cam / np.max(cam)return cam
主程序
from data_loader import load_data
from model import SimpleCNN
from train import train_model
from grad_cam import GradCAM
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import sys
import os
# 将当前目录添加到 Python 模块搜索路径中
sys.path.append(os.path.dirname(os.path.abspath(__file__)))if __name__ == '__main__':# 加载数据,设置批次大小,你可以根据需要调整该值batch_size = 32# 修改解包操作以处理所有返回值dataloaders, dataset_sizes, class_names = load_data('raw-img', batch_size)# 获取类别数量num_classes = len(class_names)# 使用类别数量初始化模型model = SimpleCNN(num_classes)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练模型trained_model = train_model(model, dataloaders, dataset_sizes, criterion, optimizer, num_epochs=5)# 生成 Grad - CAM 可视化结果# 修改此处,选择实际存在的卷积层# grad_cam = GradCAM(model, target_layer=model.conv2)grad_cam = GradCAM(model, target_layer=model.features[0])img_path = 'path/to/your/image.jpg'img = Image.open(img_path).convert('RGB')cam = grad_cam(img)plt.imshow(img)plt.imshow(cam, alpha=0.5, cmap='jet')plt.axis('off')plt.savefig('grad_cam_result.jpg')plt.show()
@浙大疏锦行
相关文章:
Python打卡训练营学习记录Day43
作业: kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化 进阶:并拆分成多个文件 从谷歌图片中拍摄的 10 种不同类别的动物图片 数据预处理 import os from torchvision import datasets, transforms from torch.utils…...

【Android基础回顾】二:handler消息机制
Android 的 Handler 机制 是 Android 应用中实现线程间通信、任务调度、消息分发的核心机制之一,它基于 消息队列(MessageQueue) 消息循环(Looper) 消息处理器(Handler) 组成。 1 handler的使用…...

每日Prompt:每天上班的状态
提示词 一个穿着清朝官服的僵尸脸上贴着符纸,在电脑面前办公,房间阴暗,电脑桌面很乱,烟灰缸里面满是烟头...
.net ORM框架dapper批量插入
.NET ORM 框架 Dapper 批量插入全解析 在 .NET 开发中,与数据库交互是常见需求。Dapper 作为轻量级的 ORM(对象关系映射)库,在简化数据库交互方面表现出色。今天我们就来深入探讨 Dapper 实现批量插入的几种方法。 为什么需要批…...

C++11 右值引用:从入门到精通
文章目录 一、引言二、左值和右值(一)概念(二)区别和判断方法 三、左值引用和右值引用(一)左值引用(二)右值引用 四、移动语义(一)概念和必要性(二…...

.net 使用MQTT订阅消息
在nuGet下载M2Mqtt V4.3.0版本。(支持.net framework) 订阅主题 public void LoadMQQCData() {string enpoint "xxx.xxx.x.x";//ip地址int port 1883;//端口string user "usrname";//用户名string pwd "pwd";//密码…...
Python实现快速排序的三种经典写法及算法解析
今天想熟悉一下python的基础写法,那就从最经典的快速排序来开始吧: 1、经典分治写法(原地排序) 时间复杂度:平均O(nlogn),最坏O(n) 空间复杂度:O(logn)递归栈空间 特点:通过左右指针…...

【递归、搜索与回溯】综合练习(四)
📝前言说明: 本专栏主要记录本人递归,搜索与回溯算法的学习以及LeetCode刷题记录,按专题划分每题主要记录:(1)本人解法 本人屎山代码;(2)优质解法 优质代码…...

强化学习入门:Gym实现CartPole随机智能体
前言 最近想开一个关于强化学习专栏,因为DeepSeek-R1很火,但本人对于LLM连门都没入。因此,只是记录一些类似的读书笔记,内容不深,大多数只是一些概念的东西,数学公式也不会太多,还望读者多多指教…...

STM32:CAN总线精髓:特性、电路、帧格式与波形分析详解
声明:此博客是我的学习笔记,所看课程是江协科技的CAN总线课程,知识点都大同小异,我仅进行总结并加上了我自己的理解,所引案例也都是课程中的案例,希望对你的理解有所帮助! 知识点1【CAN总线的概…...

贝叶斯深度学习!华科大《Nat. Commun.》发表BNN重大突破!
华科大提出基于贝叶斯深度学习的超分辨率成像,成功被Nat. Commun.收录。可以说,这是贝叶斯神经网络BNN近期最值得关注的成果之一了。另外还有AAAI 2025上的Bella新框架,计算成本降低了99.7%,也非常值得研读。 显然鉴于BNN“不确定…...

【大模型LLM学习】Flash-Attention的学习记录
【大模型LLM学习】Flash-Attention的学习记录 0. 前言1. flash-attention原理简述2. 从softmax到online softmax2.1 safe-softmax2.2 3-pass safe softmax2.3 Online softmax2.4 Flash-attention2.5 Flash-attention tiling 0. 前言 Flash Attention可以节约模型训练和推理时间…...
三、元器件的选型
前言:我们确立了题目的功能后,就可以开始元器件的选型,元器件的选型关乎到我们后面代码编写的一个难易。 一、主控的选择 主控的选择很大程度上决定我们后续使用的代码编译器,比如ESP32使用的是VScode,或者Arduino&a…...
精益数据分析(95/126):Socialight的定价转型启示——B2B商业模式的价格策略与利润优化
精益数据分析(95/126):Socialight的定价转型启示——B2B商业模式的价格策略与利润优化 在创业过程中,从B2C转向B2B不仅是商业模式的转变,更是定价策略与成本结构的全面重构。今天,我们将通过Socialight的实…...
stm32_DMA
DMA 1. 概念与基本原理 DMA,全称Direct Memory Access,即直接存储器访问。它是微控制器(MCU)、嵌入式处理器中的一个独立硬件模块,用于在无需CPU干预的情况下,在不同内存区域(包括外设寄存器和…...

物联网数据归档之数据存储方案选择分析
在上一篇文章中《物联网数据归档方案选择分析》中凯哥分析了归档设计的两种方案,并对两种方案进行了对比。这篇文章咱们就来分析分析,归档后数据应该存储在哪里?及存储方案对比。 这里就选择常用的mysql及taos数据库来存储归档后的数据吧。 你在处理设备归档表存储方案时对…...
【自动驾驶避障开发】如何让障碍物在 RViz 中‘显形’?呈现感知数据转 Polygon 全流程
【自动驾驶避障开发】如何让障碍物在 RViz 中"显形"?呈现感知数据转 Polygon 全流程 自动驾驶系统中的障碍物可视化是开发调试过程中至关重要的一环。本文将详细介绍如何将自动驾驶感知模块检测到的障碍物数据转换为RViz可显示的Polygon(多边形)形式,实现障碍物…...

【C语言】C语言经典小游戏:贪吃蛇(上)
文章目录 一、游戏背景及其功能二、Win32 API介绍1、Win32 API2、控制台程序3、定位坐标(COORD)4、获得句柄(GetStdHandle)5、获得光标属性(GetConsoleCursorInfo)1)描述光标属性(CO…...
usbutils工具的使用帮助
作为嵌入式系统开发中的常用工具,usbutils 是一套用于管理和调试USB设备的Linux命令行工具集。以下是其核心功能和使用方法的详细说明: 1. 工具组成 核心命令: lsusb:列出所有连接的USB设备及详细信息(默认安装&#…...

vue2中使用jspdf插件实现页面自定义块pdf下载
pdf下载 实现pdf下载的环境安装jspdf插件在项目中使用 实现pdf下载的环境 项目需求案例背景,点击【pdf下载】按钮,弹出pdf下载弹窗,显示需要下载四个模块的下载进度,下载完成后,关闭弹窗即可! 项目使用的是…...

如何防止服务器被用于僵尸网络(Botnet)攻击 ?
防止服务器被用于僵尸网络(Botnet)攻击是关键的网络安全措施之一。僵尸网络是黑客利用大量被感染的计算机、服务器或物联网设备来发起攻击的网络。以下是关于如何防止服务器被用于僵尸网络攻击的技术文章: 防止服务器被用于僵尸网络ÿ…...

基于cornerstone3D的dicom影像浏览器 第二十九章 自定义菜单组件
文章目录 前言一、程序结构1. 菜单数据结构2. XMenu.vue3. XSubMenu.vue4. XSubMenuSlot.vue5. XMenuItem.vue 二、调用流程总结 前言 菜单用于组织程序功能,为用户提供导航。是用户与程序交互非常重要的接口。 开源组件库像Element Plus和Ant Design中都提供了功能…...

【Block总结】DBlock,结合膨胀空间注意模块(Di-SpAM)和频域模块Gated-FFN|即插即用|CVPR2025
论文信息 标题: DarkIR: Robust Low-Light Image Restoration 作者: Daniel Feijoo, Juan C. Benito, Alvaro Garcia, Marcos Conde 论文链接:https://arxiv.org/pdf/2412.13443 GitHub链接:https://github.com/cidautai/DarkIR 创新点 DarkIR提出了…...
【学习笔记】单例类模板
【学习笔记】单例类模板 一、单例类模板 以下为一个通用的单例模式框架,这种设计允许其他类通过继承Singleton模板类来轻松实现单例模式,而无需为每个类重复编写单例实现代码。 // 命名空间(Namespace) 和 模板(Tem…...
字符串加密(华为OD)
题目描述 给你一串未加密的字符串str,通过对字符串的每一个字母进行改变来实现加密,加密方式是在每一个字母str[i]偏移特定数组元素a[i]的量,数组a前三位已经赋值:a[0]=1,a[1]=2,a[2]=4。当i>=3时,数组元素a[i]=a[i-1]+a[i-2]+a[i-3]。例如:原文 abcde 加密后 bdgkr,…...

口罩佩戴检测算法AI智能分析网关V4工厂/工业等多场景守护公共卫生安全
一、引言 在公共卫生安全日益受到重视的当下,口罩佩戴成为预防病毒传播、保障人员健康的重要措施。为了高效、精准地实现对人员口罩佩戴情况的监测,AI智能分析网关V4口罩检测方案应运而生。该方案依托先进的人工智能技术与强大的硬件性能,…...

Double/Debiased Machine Learning
独立同步分布的观测数据 { W i ( Y i , D i , X i ) ∣ i ∈ { 1 , . . . , n } } \{W_i(Y_i,D_i,X_i)| i\in \{1,...,n\}\} {Wi(Yi,Di,Xi)∣i∈{1,...,n}},其中 Y i Y_i Yi表示结果变量, D i D_i Di表示因变量, X i X_i Xi表…...

HarmonyOS Next 弹窗系列教程(4)
HarmonyOS Next 弹窗系列教程(4) 介绍 本章主要介绍和用户点击关联更加密切的菜单控制(Menu) 和 气泡提示(Popup) 它们出现显示弹窗出现的位置都是在用户点击屏幕的位置相关 菜单控制(Menu&…...

【C】-递归
1、递归概念 递归(Recursion)是编程中一种重要的解决问题的方法,其核心思想是函数通过调用自身来解决规模更小的子问题,直到达到最小的、可以直接解决的基准情形(Base Case)。 核心:自己调用…...

飞马LiDAR500雷达数据预处理
0 引言 在使用飞马D2000无人机搭载LiDAR500进行作业完成后,需要对数据进行预处理,方便给内业人员开展点云分类等工作。在开始操作前,先了解一下使用的软硬件及整体流程。 0.1 外业测量设备 无人机:飞马D2000S激光模块ÿ…...