当前位置: 首页 > article >正文

DAY43打卡

@浙大疏锦行

kaggle找到一个图像数据cnn网络进行训练并且grad-cam可视化

进阶并拆分成多个文件

fruit_cnn_project/
├─ data/                # 存放数据集(需手动创建,后续放入图片)
│  ├─ train/            # 训练集图像
│  └─ val/              # 验证集图像
├─ models/              # 模型定义
│  └─ cnn_model.py      # CNN网络结构
├─ utils/               # 工具函数
│  ├─ dataset_utils.py  # 数据加载与预处理
│  ├─ grad_cam.py       # Grad-CAM可视化
│  └─ train_utils.py    # 训练与评估
├─ main.py              # 主程序
└─ requirements.txt     # 依赖列表(可选)
# 第一部分:导入库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline# 第二部分:数据加载与预处理
def load_data():data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_dataset = datasets.ImageFolder(root='data/train', transform=data_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)test_dataset = datasets.ImageFolder(root='data/test', transform=data_transform)test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)return train_loader, test_loader# 第三部分:模型定义
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.fc1 = nn.Linear(32 * 56 * 56, 128)self.fc2 = nn.Linear(128, 2)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 32 * 56 * 56)x = self.relu(self.fc1(x))x = self.fc2(x)return x# 第四部分:模型训练
train_loader, _ = load_data()
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')torch.save(model.state_dict(), 'trained_model.pth')# 第五部分:模型测试
_, test_loader = load_data()
model = SimpleCNN()
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()
correct = 0
total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the test images: {100 * correct / total}%')# 第六部分:Grad-CAM可视化(修复版)
def get_activation():activation = {}def hook(model, input, output):activation['target_layer'] = output.detach()return hook, activationdef grad_cam(model, image, target_class_index):hook, activation = get_activation()target_layer = model.conv2target_layer.register_forward_hook(hook)model.eval()image = image.unsqueeze(0)image.requires_grad_(True)output = model(image)one_hot = torch.zeros(1, output.size()[-1]).to(image.device)one_hot[0][target_class_index] = 1output.backward(gradient=one_hot, retain_graph=True)gradients = image.grad[0].cpu().numpy()# 从activation字典中获取激活图activation_map = activation['target_layer'].cpu().numpy()[0]weights = np.mean(gradients, axis=(1, 2))cam = np.zeros(activation_map.shape[1:], dtype=np.float32)for i, w in enumerate(weights):cam += w * activation_map[i]cam = np.maximum(cam, 0)cam = F.interpolate(torch.from_numpy(cam).unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)[0][0].numpy()cam = (cam - cam.min()) / (cam.max() - cam.min())return cam# 可视化前几张测试图片
dataiter = iter(test_loader)
images, labels = dataiter.next()for i in range(5):  # 可视化前5张图片image = images[i]label = labels[i].item()cam = grad_cam(model, image, label)plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.imshow(image.permute(1, 2, 0).numpy())plt.title(f'Original Image (Class: {label})')plt.axis('off')plt.subplot(1, 2, 2)plt.imshow(image.permute(1, 2, 0).numpy())plt.imshow(cam, cmap='jet', alpha=0.5)plt.title('Grad-CAM Visualization')plt.axis('off')plt.tight_layout()plt.show()

相关文章:

DAY43打卡

浙大疏锦行 kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化 进阶:并拆分成多个文件 fruit_cnn_project/ ├─ data/ # 存放数据集(需手动创建,后续放入图片) │ ├─ train/ …...

Leetcode 1892. 页面推荐Ⅱ

1.题目基本信息 1.1.题目描述 表: Friendship ---------------------- | Column Name | Type | ---------------------- | user1_id | int | | user2_id | int | ---------------------- (user1_id,user2_id) 是 Friendship 表的主键(具有唯一值的列的组合…...

进程——环境变量及程序地址空间

目录 环境变量 概念 补充:命令行参数 引入 其它环境变量 理解 程序地址空间 引入 理解 虚拟地址存在意义 环境变量 概念 环境变量一般是指在操作系统中用来指定操作系统运行环境的一些参数。打个比方,就像你布置房间,这些参数就类…...

(4-point Likert scale)4 点李克特量表是什么

文章目录 4-point Likert scale 定义4-point Likert scale 的构成4-point Likert scale 的特点4-point Likert scale 的应用场景 4-point Likert scale 定义 4-point Likert scale(4 点李克特量表)是一种常用的心理测量量表,由美国社会心理学…...

亚矩阵云手机实测体验:稳定流畅背后的技术逻辑​

最近在测试一款云手机服务时,发现亚矩阵的表现出乎意料地稳定。作为一个经常需要多设备协作的开发者,我对云手机的性能、延迟和稳定性要求比较高。经过一段时间的体验,分享一下真实感受,避免大家踩坑。 ​​1. 云手机能解决什么问…...

VR视频制作有哪些流程?

VR视频制作流程知识 VR视频制作,作为融合了创意与技术的复杂制作过程,涵盖从初步策划到最终呈现的多个环节。在这个过程中,我们可以结合众趣科技的产品,解析每一环节的实现与优化,揭示背后的奥秘。 VR视频制作有哪些…...

NodeJS全栈WEB3面试题——P2智能合约与 Solidity

2.1 简述 Solidity 的数据类型、作用域、函数修饰符。 数据类型: 值类型(Value Types):uint, int, bool, address, bytes1 到 bytes32, enum 引用类型(Reference Types):array, struct, mappin…...

某水表量每15分钟一报,然后某天示数清0了,重新报示值了 ,如何写sql 计算每日水量

要计算每日电量,需处理电表清零的情况。以下是针对不同数据库的解决方案: 方法思路 识别清零点:通过比较当前值与前一个值,若当前值明显变小(如小于前值的10%),则视为清零。分段累计&#xff…...

Ubuntu 系统部署 MySQL 入门篇

一、安装 MySQL 1.1 更新软件包 在终端中执行以下命令,更新系统软件包列表,确保安装的是最新版本的软件: sudo apt update 1.2 安装 MySQL 执行以下命令安装 MySQL 服务端: sudo apt install mysql-server 在安装过程中&…...

【MATLAB代码】制导——平行接近法,三维,目标是运动的,订阅专栏后可直接查看MATLAB源代码

文章目录 运行结果简介代码功能概述运行结果核心模块解析代码特性与优势MATLAB例程代码调整说明相关公式视线角速率约束相对运动学方程导引律加速度指令运动学更新方程拦截条件判定运行结果 运行演示视频: 三维平行接近法导引运行演示 简介 代码功能概述 本代码实现了三维空…...

大模型安全测试报告:千问、GPT 全系列、豆包、Claude 表现优异,DeepSeek、Grok-3 与 Kimi 存在安全隐患

大模型安全测试报告:千问、GPT 全系列、豆包、Claude 表现优异,DeepSeek、Grok-3 与 Kimi 存在安全隐患 引言 随着生成式人工智能技术的快速演进,大语言模型(LLM)正在广泛应用于企业服务、政务系统、教育平台、金融风…...

vue3 按钮级别权限控制

在Vue 3中实现按钮级别的权限控制,可以通过多种方式实现。这里我将介绍几种常见的方法: 方法1:使用Vue 3的Composition API 在Vue 3中,你可以使用Composition API来创建一个可复用的逻辑来处理权限控制。 创建权限控制逻辑 首…...

vue3子组件获取并修改父组件的值

在子组件中,父组件传递来的 prop 是只读的,但是确实有修改的需求,故此做个小小研究 // 父组件使用模版:update:xxx"dialogVisible $event" // 子组件使用模版 // const emits defineEmits([update:xxx]); // emits(u…...

【Redis】Cluster集群

目录 1、背景2、核心特性【1】数据分片【2】高可用【3】去中心化【4】客户端重定向 3、集群架构【1】最小规模【2】节点角色【3】通信协议 4、数据分片与路由【1】哈希槽分配【2】客户端路由逻辑 5、故障恢复6、适用场景 1、背景 Redis Cluster是Redis官方提供的分布式解决方案…...

黑马Java面试笔记之 微服务篇(SpringCloud)

一. SpringCloud 5大组件 SpringCloud 5大组件有哪些? 总结 五大件分别有: Eureka:注册中心Ribbon:负载均衡Feign:远程调用Hystrix:服务熔断Zuul/Gateway:网关 如果项目用到了阿里巴巴&#xff…...

CLIP多模态大模型的优势及其在边缘计算中的应用

CLIP多模态大模型的优势及其在边缘计算中的应用 CLIP(Contrastive Language-Image Pre-training)模型,是OpenAI开发的一种多模态大模型。该模型通过对比学习的方式,在大规模图像-文本对上进行预训练,成功实现了图像和文…...

基于STM32语音识别柔光台灯

基于STM32语音识别柔光台灯 (程序+原理图+PCB+设计报告) 功能介绍 具体功能: 基于语音识别的智能LED柔光台灯设计,主要包括语音识别模块应用,PWM波控制LED柔光灯的亮度&#xff0c…...

基于PSO粒子群优化的VMD-GRU时间序列预测算法matlab仿真

目录 1.前言 2.算法运行效果图预览 3.算法运行软件版本 4.部分核心程序 5.算法仿真参数 6.算法理论概述 6.1变分模态分解(VMD) 6.2 门控循环单元(GRU) 6.3 粒子群优化(PSO) 7.参考文献 8.算法完…...

探索未知惊喜,盲盒抽卡机小程序系统开发新启航

在消费市场不断追求新鲜感与惊喜体验的当下,盲盒抽卡机以其独特的魅力,迅速成为众多消费者热衷的娱乐与消费方式。我们紧跟这一潮流趋势,专注于盲盒抽卡机小程序系统的开发,致力于为商家和用户打造一个充满趣味与惊喜的数字化平台…...

基于开源AI大模型与AI智能名片的S2B2C商城小程序源码优化:企业成本管理与获客留存的新范式

摘要:本文以企业成本管理的两大核心——外部成本与内部成本为切入点,结合开源AI大模型、AI智能名片及S2B2C商城小程序源码技术,构建了企业数字化转型的“技术-成本-运营”三维模型。研究结果表明,通过AI智能名片实现获客留存效率提…...

Python----目标检测(YOLO简介)

一、 YOLO简介 [YOLO](You Only Look Once)是一种流行的物体检测和图像分割模型, 由华盛顿大学的约瑟夫-雷德蒙(Joseph Redmon)和阿里-法哈迪(Ali Farhadi)开发,YOLO 于 2015 年推出&#xff0c…...

mysql+keepalived

文章目录 一、master1创建目录写入配置文件启动master1创建 `slave` 用户并授权获取主节点当前 `binary log` 文件名和位置position二、master2创建目录写入配置文件启动master2创建 `slave` 用户并授权获取主节点当前 `binary log` 文件名和位置position三、配置主主复制Maste…...

Profinet 协议 IO-Link 主站网关(三格电子)

一、产品概述 1.1 产品用途 SG-PN-IOL-8A-001 网关是 Profinet 从转 IO-Link 主的网关设备 ,可以将 IO-Link 从站设备接入 Profinet 系统,通过该网关可实现传感器及驱动器与控制 器之间的信息交互。网关有两个百兆网口和 8 个 IO-Link 端口,两…...

Ubuntu22.04 安装 Miniconda3

Conda 是一个开源的包管理系统和环境管理系统,可用于 Python 环境管理。 Miniconda 是一个轻量级的 Conda 发行版。Miniconda 包含了 Conda、Python和一些基本包,是 Anaconda 的精简版本。 1.下载安装脚本 在 conda官网 找到需要的安装版本&#xff0…...

Hubstudio浏览器如何使用Loongproxy?

1. 使用软件 1.1 Loongproxy 1. 顶级ISP资源:Loongproxy是神龙云旗下品牌,依托与全球领先ISP运营商的深度合作,Loongproxy 精选全球优质静态住宅IP资源。 2. IP池庞大:覆盖 100 国家/地区,构建庞大的 70 万 静态IP池…...

硬件工程师笔记——555定时器应用Multisim电路仿真实验汇总

目录 一 555定时器基础知识 二、引脚功能 三、工作模式 1. 单稳态模式: 2. 双稳态模式(需要外部电路辅助): 3. 无稳态模式(多谐振荡器): 4. 可控脉冲宽度调制(PWM)模式: 四、典型应用 五、优点 二 555无稳态触发器 三 555单稳态触发器 四 555双稳态触发器…...

ComfyUI 对图片进行放大的不同方法

本篇里 ComfyUI Wiki将讲解 ComfyUI 中几种基础的放大图片的办法,我们时常会因为设备性能问题,不能一次性生成大尺寸的图片,通常会先生成小尺寸的图像然后再进行放大。 不同的放大图片方法有不同的特点,以下是本篇教程将会涉及的方法: 像素重新采样SD 二次采样放大使用放…...

Elasticsearch最新入门教程

文章目录 Elasticsearch最新入门教程1.Elasticsearch安装2.Kibana安装3.Elasticsearch关键概念4.SpringBoot整合Elasticsearch4.1 导入Elasticsearch数据4.2 创建SpringBoot项目4.3 修改pom.xml文件4.4 创建es实体类4.5 创建es的查询接口 5.DSL语句5.1 无条件查询5.2 指定返回的…...

第16节 Node.js 文件系统

Node.js 提供一组类似 UNIX(POSIX)标准的文件操作API。 Node 导入文件系统模块(fs)语法如下所示: var fs require("fs") 异步和同步 Node.js 文件系统(fs 模块)模块中的方法均有异步和同步版本&#xff…...

【Linux网络篇】:从HTTP到HTTPS协议---加密原理升级与安全机制的全面解析

✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:Linux篇–CSDN博客 文章目录 HTTPS协议原理一.预备知识1.什么是“加密”2.为什么要“加密”…...