当前位置: 首页 > article >正文

Day43 Python打卡训练营

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

选取Kaggle上的CIFAR-10数据集进行CNN训练,并使用Grad-CAM进行可视化,代码将拆分为多个文件以保持模块化。CIFAR-10是一个包含60,000张32x32彩色图像的数据集,分为10个类别。

项目结构

cifar10_cnn_gradcam/
├── data_loader.py         # 数据加载和预处理
├── model.py              # CNN模型定义
├── gradcam.py            # Grad-CAM实现
├── train.py              # 模型训练逻辑
├── visualize.py          # 可视化Grad-CAM结果
├── main.py               # 主执行脚本
└── requirements.txt      # 依赖库

1. 数据加载(data_loader.py)

此文件负责加载和预处理CIFAR-10数据集,并进行训练、验证、测试集划分。

import tensorflow as tf
from sklearn.model_selection import train_test_splitdef load_cifar10_data():# 加载CIFAR-10数据集(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()# 归一化像素值到[0, 1]x_train = x_train.astype('float32') / 255.0x_test = x_test.astype('float32') / 255.0# 将训练集进一步拆分为训练和验证集(80%训练,20%验证)x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)# 类名class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']return (x_train, y_train), (x_val, y_val), (x_test, y_test), class_names

2. 模型定义 (model.py)

此文件定义一个简单的CNN模型,适合CIFAR-10分类任务。

import tensorflow as tf
from tensorflow.keras import layers, modelsdef build_cnn_model():model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3), padding='same'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu', padding='same'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu', padding='same'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dropout(0.5),layers.Dense(10, activation='softmax')])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])return model

3. Grad-CAM实现 (gradcam.py)

此文件实现Grad-CAM算法,用于生成CNN的注意力热图。

import tensorflow as tf
import numpy as np
import cv2class GradCAM:def __init__(self, model, layer_name):self.model = modelself.layer_name = layer_nameself.grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(layer_name).output, model.output])def generate_heatmap(self, image, class_idx):image = tf.cast(image, tf.float32)with tf.GradientTape() as tape:conv_output, predictions = self.grad_model(image)loss = predictions[:, class_idx]grads = tape.gradient(loss, conv_output)pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))conv_output = conv_output[0]heatmap = tf.reduce_mean(tf.multiply(conv_output, pooled_grads), axis=-1)heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)return heatmap.numpy()def superimpose_heatmap(self, image, heatmap, alpha=0.4):heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))heatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)image = np.uint8(255 * image)superimposed_img = heatmap * alpha + imagesuperimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)return superimposed_img

4. 模型训练 (train.py)

此文件包含训练逻辑,使用数据增强以提高模型鲁棒性。

import tensorflow as tf
from tensorflow.keras import layers
from model import build_cnn_modeldef train_model(x_train, y_train, x_val, y_val, epochs=25, batch_size=32):model = build_cnn_model()# 数据增强data_augmentation = tf.keras.Sequential([layers.RandomFlip("horizontal"),layers.RandomRotation(0.1),layers.RandomZoom(0.1),])# 训练模型history = model.fit(data_augmentation(x_train), y_train,validation_data=(x_val, y_val),epochs=epochs,batch_size=batch_size,verbose=1)model.save('cifar10_cnn_model.h5')return model, history

5. 可视化Grad-CAM结果 (visualize.py)

此文件负责生成和保存Grad-CAM可视化结果。

import numpy as np
import matplotlib.pyplot as plt
from gradcam import GradCAMdef visualize_gradcam(model, x_test, y_test, class_names, num_images=5):gradcam = GradCAM(model, layer_name='conv2d_2')  # 选择最后一层卷积层plt.figure(figsize=(15, 10))for i in range(num_images):img = x_test[i:i+1]true_label = y_test[i][0]pred = model.predict(img)pred_label = np.argmax(pred, axis=1)[0]# 生成热图heatmap = gradcam.generate_heatmap(img, pred_label)superimposed_img = gradcam.superimpose_heatmap(img[0], heatmap)# 可视化plt.subplot(num_images, 3, i*3 + 1)plt.imshow(img[0])plt.title(f'True: {class_names[true_label]}')plt.axis('off')plt.subplot(num_images, 3, i*3 + 2)plt.imshow(heatmap, cmap='jet')plt.title('Heatmap')plt.axis('off')plt.subplot(num_images, 3, i*3 + 3)plt.imshow(superimposed_img)plt.title(f'Pred: {class_names[pred_label]}')plt.axis('off')plt.tight_layout()plt.savefig('gradcam_visualization.png')plt.close()

6. 主执行脚本 (main.py)

此文件协调整个流程,调用其他模块执行数据加载、训练和可视化。

from data_loader import load_cifar10_data
from train import train_model
from visualize import visualize_gradcamdef main():# 加载数据(x_train, y_train), (x_val, y_val), (x_test, y_test), class_names = load_cifar10_data()# 训练模型model, history = train_model(x_train, y_train, x_val, y_val, epochs=25, batch_size=32)# 可视化Grad-CAMvisualize_gradcam(model, x_test, y_test, class_names, num_images=5)# 评估模型test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)print(f"Test accuracy: {test_acc:.4f}")if __name__ == "__main__":main()

7. 依赖文件 (requirements.txt)

列出项目所需的Python库。

tensorflow==2.10.0 numpy scikit-learn matplotlib opencv-python

@浙大疏锦行

相关文章:

Day43 Python打卡训练营

作业: kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化 进阶:并拆分成多个文件 选取Kaggle上的CIFAR-10数据集进行CNN训练,并使用Grad-CAM进行可视化,代码将拆分为多个文件以保持模块化。CIFAR-10是…...

雷卯针对易百纳 SS524多媒体处理演示评估板防雷防静电方案

一、 应用场景 1. 远程视频会议 2. 安防监控 3. 人/车检测 4. 人脸检测、比对 5. 屏幕拼接墙 二、 功能概述 1 四核 ARM Cortex-A7 1.2GHz 2 AI算力 1.0Tops 3 4K30fps 4*1080P30编解码 三、 扩展接口 l RAM:板载 2*DDR4,共 2GB; …...

【BUG解决】关于BigDecimal与0的比较问题

这是一个很细小的知识点,但是很容易被忽略掉,导致系统问题,因此记录下来 问题背景 明明逻辑上看a和b都不为0才会调用除法,但是系统会报错:java.lang.ArithmeticException异常: if (!a.equals(BigDecimal…...

Spring Bean 为何“难产”?攻克构造器注入的依赖与歧义

本文已收录在Github,关注我,紧跟本系列专栏文章,咱们下篇再续! 🚀 魔都架构师 | 全网30W技术追随者🔧 大厂分布式系统/数据中台实战专家🏆 主导交易系统百万级流量调优 & 车联网平台架构&a…...

LeetCodeHot100(图论篇)

目录 图论岛屿数量题目代码 腐烂的橘子题目代码 课程表题目代码 实现 Trie (前缀树)题目代码 后续内容持续更新~~~ 图论 岛屿数量 题目 给你一个由 ‘1’(陆地)和 ‘0’(水)组成的的二维网格,请你计算网格中岛屿的数…...

【Lecture01】动手开发科研智能体(WIN11系统)

1. 配置win11系统中的环境,安装管理器Choco: # Download and install Chocolatey: powershell -c "irm https://community.chocolatey.org/install.ps1|iex" # Download and install Node.js: choco install nodejs-lts --version"22&qu…...

“packageManager“: “pnpm@9.6.0“ 配置如何正确启动项目?

今天在学习开源项目的时候,在安装依赖时遇到了一个报错 yarn add pnpm9.6.0 error This projects package.json defines "packageManager": "yarnpnpm9.6.0". However the current global version of Yarn is 1.22.22.Presence of the "…...

Git Github Gitee GitLab

Git的工作流程 工作区(Workspace):电脑本地目录,即平时存放项目代码的地方 暂存区(Index/Stage):临时存放改动信息的地方 本地仓库(Repository):存放所有提交的版本数据 远程仓库(Remote):托管代码的服务器&#x…...

华为设备OSPF配置与实战指南

一、基础配置架构 sysname HUAWEI-ABR ospf 100 router-id 1.1.1.1area 0.0.0.0network 10.1.1.0 0.0.0.255 # 将接口加入区域0 interface GigabitEthernet0/0/1ospf enable 100 area 0.0.0.0 # 华为支持点分十进制区域号bandwidth-reference 10000 # 设置10Gbps参考带宽…...

Paraformer分角色语音识别-中文-通用 FunASR

https://github.com/modelscope/FunASR/blob/main/README_zh.md https://github.com/modelscope/FunASR/blob/main/model_zoo/readme_zh.md PyTorch / 2.3.0 / 3.12(ubuntu22.04) / 12.1 Paraformer分角色语音识别-中文-通用 https://www.modelscope.cn/models/iic/speech_p…...

Spitfire:Codigger 生态中的高性能、安全、分布式浏览器

Spitfire 是 Codigger 生态系统中的一款现代化浏览器,专为追求高效、隐私和分布式技术的用户设计。它结合了 Codigger 的分布式架构优势,在速度、安全性和开发者支持方面提供了独特的解决方案,同时确保用户对数据的完全控制。 1. 高性能浏览…...

vimadbgit命令

vim 全部选中 全选(高亮显示):按esc后,然后ggvG或者ggVG 全部复制:按esc后,然后ggyG 全部删除:按esc后,然后dG -----------------------------------------------------------------…...

运行shell脚本时报错/bin/bash^M: 解释器错误: 没有那个文件或目录

Windows的换行符为\r\n,而linux换行符为\n。先查看一下文件是什么格式的 :set ff --查询一下格式是什么 由于使用nodepad新建的脚本,首选项中格式设置成了windows,上传到linux中报错。 解决方法 1、nodepad中【设置》首选项】修改为unix&am…...

2506,wtl的通知事件

通知事件 最后一步,通知(连接)控件CMainDlg想要接受的浏览器控件触发的消息.连接在OnInitDialog(),断开在OnDestroy(). VC6中连接 VC6中,ATL的全局函数,AtlAdviseSinkMap()通知(连接)对话框中所有控件开始或终止发送事件到C对象. 该该函数的第一个参数是一个指向拥有事件映射…...

Shiro安全权限框架

①、添加依赖 ②、创建ini文件 获取权限相关信息可以通过数据库获取,也可以通过ini配置文件获取 ③、认证代码 public class ShiroRun{public static void main(){//初始化获取SecurityManagerIniSerucityManagerFactory factory new IniSecurityManagerFac…...

虚拟现实教育终端技术方案——基于EFISH-SCB-RK3588的全场景国产化替代

一、VR教育终端技术挑战与替代价值 ‌实时交互性能瓶颈‌ 赛扬N100/N150仅支持3DOF渲染(延迟>25ms),动态手势识别帧率≤15FPS,难以满足6DOF教学场景需求RK3588 Mali-G610 GPU支持6DOF空间渲染(延迟≤12ms&…...

深入理解CSS浮动:从基础原理到实际应用

深入理解CSS浮动:从基础原理到实际应用 引言 在网页设计中,CSS浮动(float)是一个历史悠久却又至关重要的概念。虽然现代布局技术如Flexbox和Grid逐渐流行,但浮动仍然在许多场景中发挥着重要作用。本文将带你深入理解…...

代码训练LeetCode(22)研究者H指数

代码训练(22)LeetCode之研究者H指数 Author: Once Day Date: 2025年6月4日 漫漫长路,才刚刚开始… 全系列文章可参考专栏: 十年代码训练_Once-Day的博客-CSDN博客 参考文章: 274. H 指数 - 力扣(LeetCode)力扣 (LeetCode) 全球极客挚爱的…...

网络安全A模块专项练习任务五解析

任务五:Linux 操作系统安全配置-1 任务环境说明: ✓ 服务器场景:LinuxServer:(开放链接) ✓ 用户名:root,密码:123456 ✓ 数据库用户名:root,密码:123456 请对服务器 LinuxServer 按要求进行相应的设置,提高服务器的安全性。 1.设置最小…...

git cli 基于远程master分支创建本地分支并切换

1、获取远程最新状态 git fetch origin2、从远程master创建本地分支并切换 git checkout -b new-branch-name origin/master或者,新版本写法 git switch -c new-branch-name origin/master3、如果要推送到远程,并建立跟踪,执行下面的命令 …...

Redis初入门

Nosql:Not-Only SQL(泛指非关系型数据库),作为关系型数据库的补充 作用:应对基于海量用户和海量数据前提下的数据处理问题 redis:C语言开发的一个开源的高性能键值对数据库 特征: 1、数据之…...

(10)Fiddler抓包-Fiddler如何设置捕获Firefox浏览器的Https会话

1.简介 经过上一篇对Fiddler的配置后,绝大多数的Https的会话,我们可以成功捕获抓取到,但是有些版本的Firefox浏览器仍然是捕获不到其的Https会话,需要我们更进一步的配置才能捕获到会话进行抓包。 2.环境 1.环境是Windows 10版…...

使用pandas实现合并具有共同列的两个EXCEL表

表1&#xff1a; 表2&#xff1a; 表1和表2&#xff0c;有共同的列“名称”&#xff0c;而且&#xff0c;表1的内容&#xff08;行数&#xff09;<表2的行数。 目的&#xff0c;根据“名称”列的对应内容&#xff0c;将表2列中的“所处行业”填写到表1相应的位置。 实现代…...

2025年- H69-Lc177--78.子集(回溯,组合)--Java版

1.题目描述 2.思路 3.代码实现 class Solution {public List<List<Integer>> subsets(int[] nums) {List<List<Integer>> resnew ArrayList<>();List<Integer> curnew ArrayList<>();//从索引0开始递归backtracking(res,cur,nums,0…...

目标检测任务的评估指标mAP50和mAP50-95

mAP50 和 mAP50-95 是目标检测任务中常用的评估指标&#xff0c;用于衡量模型在不同 交并比&#xff08;IoU&#xff09;阈值 下的平均精度&#xff08;Average Precision, AP&#xff09;。它们的区别主要体现在 IoU 阈值范围 上。 ✅ 1. mAP50&#xff08;mean Average Prec…...

C++String的学习

1、C语言中的字符串 C语言中&#xff0c;字符串是以’\0’结尾的一些字符的集合&#xff0c;为了操作方便&#xff0c;C标准库中提供了一些str系列的库函数&#xff0c;但是这些库函数与字符串是分离开的&#xff0c;不太符合OOP的思想&#xff08;即面向对象编程&#xff08;…...

java day15 (数据库)

进入数据库的学习 DB 因为数据太多了&#xff0c;方便统一管理的软件 操作就不用改代码了&#xff0c;直接改数据库则可&#xff1b; 命令就是sql语句 这些都是关系型数据库&#xff0c;sql可以控制全部&#xff0c;至于具体的环境我以前就有安装过了&#xff1b; 理解&am…...

SQL 中 IN 和 EXISTS 的区别

SQL 中 IN 和 EXISTS 的区别 1. 基本概念 1.1 IN 运算符 IN 是一个条件运算符,用于检查某个值是否存在于指定的值列表中或子查询返回的结果集中。 SELECT * FROM employees WHERE department_id IN (SELECT id FROM departments WHERE location = New York)...

多线程爬虫使用代理IP指南

多线程爬虫能有效提高工作效率&#xff0c;如果配合代理IP爬虫效率更上一层楼。作为常年使用爬虫做项目的人来说&#xff0c;选择优质的IP池子尤为重要&#xff0c;之前我讲过如果获取免费的代理ip搭建自己IP池&#xff0c;虽然免费但是IP可用率极低。 在多线程爬虫中使用代理I…...

前端面试真题(第一集)

目录标题 1、跨域问题及解决方法同源策略生产环境解决方案开发环境解决方案其他解决方案 2、组件间通信方式Vue2中的组件通信方式Vue3中的组件通信方式通用注意事项 3、微信小程序生命周期微信小程序原生生命周期UniApp生命周期 4、微信小程序授权登录流程登录流程手机号获取 5…...