模型解释与可解释AI实战
一、为什么需要模型解释?
模型解释技术帮助:
- 理解模型决策依据(特征重要性)
- 调试模型错误预测
- 满足监管合规要求(金融/医疗)
- 提升用户对AI的信任
本章使用Captum实现CV/NLP模型的可视化解释
二、环境准备与工具安装
!pip install captum torchvision matplotlib
import torch
import numpy as np
from captum.attr import IntegratedGradients, LayerGradCam
import matplotlib.pyplot as plt
三、图像分类解释实战(CIFAR-10)
1. 加载预训练模型
from torchvision.models import resnet18model = resnet18(pretrained=True)
model.eval()
2. 准备测试图像
from torchvision import transformstransform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# 加载示例图像(类别:ship)
from PIL import Image
img = Image.open("test_ship.jpg").convert("RGB")
input_tensor = transform(img).unsqueeze(0)
3. 集成梯度解释
def visualize_attr(attr, title):attr = attr.squeeze().cpu().detach().numpy()plt.imshow(attr, cmap='hot')plt.colorbar()plt.title(title)plt.show()# 计算特征重要性
integrated_grad = IntegratedGradients(model)
attr_ig = integrated_grad.attribute(input_tensor, target=8) # ship类别ID为8
visualize_attr(attr_ig.mean(dim=1), "Integrated Gradients")
4. Grad-CAM可视化
# 选择目标卷积层
target_layer = model.layer4.conv2# 计算Grad-CAM
layer_gradcam = LayerGradCam(model, target_layer)
attr_gc = layer_gradcam.attribute(input_tensor, target=8)# 可视化叠加效果
heatmap = np.clip(attr_gc.squeeze().cpu().detach().numpy(), 0, None)
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())orig_img = input_tensor.squeeze().permute(1,2,0).cpu().detach().numpy()
plt.imshow(orig_img * 0.5 + heatmap * 0.5)
plt.title("Grad-CAM Visualization")
plt.show()
四、文本分类解释实战(IMDB情感分析)
1. 加载情感分析模型
from transformers import AutoTokenizer, AutoModelForSequenceClassificationtokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
2. 构建解释器
from captum.attr import LayerIntegratedGradients# 定义输入处理函数
def model_forward(input_ids, attention_mask=None):return model(input_ids, attention_mask).logits# 初始化解释器
lig = LayerIntegratedGradients(model_forward,model.bert.embeddings
)
3. 计算词元重要性
text = "This movie is a complete disaster, full of terrible acting and pointless scenes."
inputs = tokenizer(text, return_tensors="pt")# 计算基准值(空输入)
ref_input_ids = torch.tensor([tokenizer.cls_token_id] + [tokenizer.pad_token_id]*(inputs.input_ids.shape-2) + [tokenizer.sep_token_id], device='cpu').unsqueeze(0)# 计算归因值
attributions, delta = lig.attribute(inputs=inputs.input_ids,baselines=ref_input_ids,additional_forward_args=(inputs.attention_mask,),return_convergence_delta=True,target=0 # 负面情感对应的类别
)# 可视化结果
token_attributions = attributions.sum(dim=2).squeeze(0)
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids)plt.figure(figsize=(12, 3))
plt.bar(range(len(tokens)), token_attributions.detach().numpy())
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.title("Token Importance Scores")
plt.show()
五、高级解释技巧
1. 对比解释(对比不同类别)
# 对比飞机(类别0)与鸟类(类别2)的解释差异
attr_plane = integrated_grad.attribute(input_tensor, target=0)
attr_bird = integrated_grad.attribute(input_tensor, target=2)plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.imshow(attr_plane.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Airplane Attribution')plt.subplot(1,2,2)
plt.imshow(attr_bird.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Bird Attribution')
plt.show()
2. 层次相关性传播(LRP)
from captum.attr import LRPlrp = LRP(model)
attr_lrp = lrp.attribute(input_tensor, target=8)plt.imshow(attr_lrp.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Layer-wise Relevance Propagation')
plt.show()
六、常见问题解答
Q1:如何安装最新版Captum?
pip install git+https://github.com/pytorch/captum.git
Q2:归因结果全为0怎么办?
- 检查输入是否经过正确的归一化
- 尝试不同的基线值(Baseline)
- 验证模型是否真的使用该特征
import matplotlib
matplotlib.use('Agg') # 无GUI模式plt.ioff()
fig = plt.figure()
# ...生成图像...
fig.savefig('explanation.png', bbox_inches='tight')
plt.close(fig)
相关文章:
模型解释与可解释AI实战
一、为什么需要模型解释? 模型解释技术帮助: 理解模型决策依据(特征重要性)调试模型错误预测满足监管合规要求(金融/医疗)提升用户对AI的信任 本章使用Captum实现CV/NLP模型的可视化解释 二、环境…...
1、pytest基本用法
目录 先给大家分享下学习资源 1. 安装pytest 2. 编写用例规则 3. 执行用例 最近在学习pytest的用法 并且用这套框架替换了原来的unittest, 同是测试框架 确实感觉到pytest更加便捷 这边分享给大家我得学习心得 先给大家分享下学习资源 1 官方文档 pytest 官方…...
【八股文】http怎么建立连接的
http协议的连接建立过程主要基于TCP协议,核心步骤包括TCP连接建立、HTTP协议交互 TCP连接建立 三次握手 客户端与服务器通过TCP协议建立连接,需完成三次握手: SYN包:客户端发送SYN报文,请求建立连接。SYN-ACK包&…...
人工智能AI术语
人工智能(AI)术语是理解人工智能领域的重要组成部分,涵盖了从基础概念到具体技术的广泛内容。这些术语不仅帮助我们理解AI技术的本质,还为研究者、开发者和决策者提供了重要的参考依据。通过掌握这些术语,我们可以更好…...
制作PaddleOCR/PaddleHub的Docker镜像
背景 在落地RAG知识库过程中,遇到了图文识别、图片表格内容识别的需求。但那时(2024年4月)各开源RAG项目还没有集成成熟的解决方案,经调研我选择了百度开源的PaddleOCR。支持国产! 概念梳理 PaddleOCR 百度飞桨的OCR…...
Ubuntu部署Docker搭建靶场
前言 我们需要部署Docker来搭建靶场题目,他可以提供一个隔离的环境,方便在不同的机器上部署,接下来,我会记录我的操作过程,简单的部署一道题目 Docker安装 不推荐在物理机上部署,可能会遇到一些问题&…...
【DFS】羌笛何须怨杨柳,春风不度玉门关 - 4. 二叉树中的深搜
本篇博客给大家带来的是二叉树深度优先搜索的解法技巧,在后面的文章中题目会涉及到回溯和剪枝,遇到了一并讲清楚. 🐎文章专栏: DFS 🚀若有问题 评论区见 ❤ 欢迎大家点赞 评论 收藏 分享 如果你不知道分享给谁,那就分享给薯条. 你们的支持是我不断创作的…...
制作rpm包
使用nfpm制作rpm包,下面是做包使用到的关键文件。 . |-- makefile |-- nfpm.yaml -- scripts |-- postinstall.sh |-- postremove.sh |-- preinstall.sh -- preremove.sh preinstall:在npm install命令前执行 install,postinstal…...
搭建Redis主从集群
主从集群说明 单节点Redis的并发能力是有上限的,要进一步提高Redis的并发能力,就需要搭建主从集群,实现读写分离。 主从结构 这是一个简单的Redis主从集群结构 集群中有一个master节点、两个slave节点(现在叫replica)…...
1.NextJS基础
NextJS注意要点 文件用来定义路由,folder name becomes the route name注意区分客户端渲染和服务器渲染 html渲染完成后给到客户端(此时网页内容已经全部提供),有利于crawler和优化seo逻辑更简单request deduplication减少API请求…...
【时时三省】(C语言基础)选择结构和条件判断
山不在高,有仙则名。水不在深,有龙则灵。 ----CSDN 时时三省 选择结构和条件判断 在现实生活中需要进行判断和选择的情况是很多的。如:从北京出发上高速公路,到一个岔路口,有两个出口,一个是去上海方向,另一个是沈阳方向。驾车者到此处必须进行判断,根据自己的目的地…...
作业12 (2023-05-15 指针概念)
第1题/共11题【单选题】 关于指针的概念,错误的是:( ) A.指针变量是用来存放地址的变量 B.指针变量中存的有效地址可以唯一指向内存中的一块区域 C.野指针也可以正常使用 D.局部指针变量不初始化就是野指针 回答正确 答案解析: A:正确,指针变量中存储的是一个地址,指…...
WSL2增加memory问题
我装的是Ubuntu24-04版本,所有的WSL2子系统默认memory为主存的一半(我的电脑是16GB,wsl是8GB),可以通过命令查看: free -h #查看ubuntu的memory和swap (改过的11GB) 前几天由于配置E…...
git 合并多次提交 commit
在工作中,有时候在反复修改代码中(比如处理MR的检视意见,或者为了推送到测试环境,先 commit到自己的远程分支上)不免会有多次 commit,这样发起 MR 的时候,就会有一堆 commit 信息,看…...
Wireshark网络抓包分析使用详解
序言 之前学计网还有前几天备考华为 ICT 网络赛道时都有了解认识 Wireshark,但一直没怎么专门去用过,也没去系统学习过,就想趁着备考的网络相关知识还没忘光,先来系统学下整理点笔记~ 什么是抓包?抓包就是将网络传输…...
【OpenGL】GLSL基础语法
GLSL(OpenGL Shading Language)是用于编写 OpenGL 着色器程序的高级编程语言,主要分为顶点着色器(Vertex Shader)、片段着色器(Fragment Shader),有时还会用到几何着色器(…...
前端实现截图功能
前端实现截图 在前端开发中,有时我们需要在网页中实现截图功能。无论是为了记录页面内容、生成报告,还是制作网页截图,掌握如何在浏览器中进行截图是非常实用的。今天,我将通过一个简单的示例,介绍如何使用 html2canv…...
如何分析和解决服务器的僵尸进程问题
### 如何分析和解决服务器的僵尸进程问题 #### **一、僵尸进程的定义与影响** **僵尸进程(Zombie Process)** 是已终止但未被父进程回收资源的进程。其特点: - **状态标识**:在进程列表(如 ps 或 top)中标…...
智能提示词生成器:助力测试工程师快速设计高质量测试用例
在软件测试中,测试用例设计方法的选择和实施是确保软件质量的重要步骤。测试工程师经常需要根据不同的测试场景、参数维度和业务需求,设计出覆盖率高且有效的测试用例。然而,设计测试用例并非易事,特别是在面对复杂的业务逻辑时。 为了帮助测试工程师高效生成测试用例提示…...
XXL-Job 二次分片是怎么做的?有什么问题?怎么去优化的?
XXL-JOB二次分片机制及优化策略 二次分片实现原理 XXL-JOB的二次分片是在分片广播策略的基础上,由开发者自行实现的更细粒度数据拆分。核心流程如下: 初次分片:调度中心根据执行器实例数量(总分片数n)分配分片索引i&…...
java版嘎嘎快充玉阳软件互联互通中电联云快充协议充电桩铁塔协议汽车单车一体充电系统源码uniapp
演示: 微信小程序:嘎嘎快充 http://server.s34.cn:1888/ 系统管理员 admin/123456 运营管理员 yyadmin/Yyadmin2024 运营商 operator/operator2024 系统特色: 多商户、汽车单车一体、互联互通、移动管理端(开发中) 另…...
SpringMVC 配置详解
SpringMVC 是 Spring 框架中用于构建 Web 应用程序的模块,它基于 MVC(Model-View-Controller)设计模式,能够将业务逻辑、数据和显示分离,从而提高代码的可维护性和可扩展性。本文将详细介绍 SpringMVC 的配置步骤和相关…...
详细Linux中级知识(不断完善)
Nginx服务配置 基于主机名配置 映射IP和主机名 [rootlocalhost ~]# vim /etc/hosts 192.168.72.135 www.chengke.com chengke[rootlocalhost ~]# echo "192.168.72.135 www.xx.com" >> /etc/hosts以上是两种方法,前面是你的IP地址,后…...
Spatial Multiplexing Power Save
802.11n中添加的PSMP,SMPS机制。 SM 节能功能可让 STA 在大部分时间内仅通过一条活动接收链运行,从而达到节能目的。 空间复用省电(Spatial Multiplexing Power Save)模式下,节点会关闭多余的天线,仅仅使用一根天线进…...
2025年渗透测试面试题总结-某360-企业蓝军面试复盘 (题目+回答)
网络安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 360-企业蓝军 一、Shiro绕WAF实战方案 二、WebLogic遭遇WAF拦截后的渗透路径 三、JBoss/WebLogic反序…...
实时图像处理:让你的应用更智能
I. 引言 实时图像处理在现代应用中扮演着重要的角色,它能够使应用更加智能、响应更加迅速。本文将深入探讨实时图像处理的原理、部署过程以及未来的发展趋势,旨在帮助开发者更好地理解如何将实时图像处理应用于他们的项目中。 II. 实时图像处理的基础概…...
C语言基础—函数指针与指针函数
函数指针 定义 函数指针本质上是指针,它是函数的指针(定义了一个指针变量,变量中存储了函数的地址)。函数都有一个入口地址,所谓指向函数的指针,就是指向函数的入口地址。这里函数名就代表入口地址。 函…...
用DrissionPage升级网易云音乐爬虫:更稳定高效地获取歌单音乐(附原码)
一、传统爬虫的痛点分析 原代码使用requests re的方案存在以下局限性: 动态内容缺失:无法获取JavaScript渲染后的页面内容 维护成本高:网页结构变化需频繁调整正则表达式 反爬易触发:简单请求头伪造容易被识别 资源消耗大&am…...
OpenCV图像拼接(5)构建图像的拉普拉斯金字塔 (Laplacian Pyramid)
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::detail::createLaplacePyr 是 OpenCV 中的一个函数,用于构建图像的拉普拉斯金字塔 (Laplacian Pyramid)。拉普拉斯金字塔是一种多…...
03 Python 基础:数据类型、运算符与流程控制解析
文章目录 一、数据类型 内置的六大类数字类型整数类型 int浮点数 float布尔 bool字符串 str 变量命名 二、数字类型的相互转换显式类型的转换整数,浮点数,复数 之间的显式转换 隐式类型的转换 三、标识符算术运算符比较运算符逻辑运算符位运算符赋值运算…...
