当前位置: 首页 > article >正文

最近帮实验室刚入门的师弟复现了西储大学轴承故障的迁移学习代码,本来以为是手到擒来的活,结果还是踩了好几个坑,刚好整理出来给同样摸鱼入门的小伙伴参考

一区top轴承故障诊断迁移学习代码复现 故障诊断代码 复现 首先使用一维的cnn对源域和目标域进行特征提取域适应阶段将源域和目标域作为cnn的输入得到特征然后进行边缘概率分布对齐和条件概率分布对齐也就是进行JDA联合对齐。此域适应方法特别适合初学者了解迁移学习的基础知识。 数据预处理1维数据 网络模型1D-CNN-MMD-Coral 数据集西储大学CWRU 准确率99% 网络框架pytorch 结果输出损失曲线图、准确率曲线图、混淆矩阵、tsne图 使用对象初学者 注意此代码是一个在 GPU 上跑的代码有宝子的电脑只支持 cpu只需要将代码修改成只在 cpu 上跑的就行这个项目真的太适合新手练手了用最简单的1D-CNN提取振动信号特征再用JDA做联合域对齐不光能快速跑通还能实打实看懂迁移学习到底在对齐什么东西比那些堆了一堆Transformer的花活代码友好太多。先唠唠整体思路说白了就是三步把西储大学的振动数据分成源域比如负载0的故障样本和目标域比如负载1的故障样本用1D-CNN从两类数据里提取特征用JDA把源域和目标域的特征分布拉到一起让模型在源域学的故障知识能直接用到目标域上全程用PyTorch写的GPU跑起来超快没有GPU也能改CPU版本完全符合大家的需求。第一步数据预处理西储大学的数据集是1维的振动信号我一般会存成npy格式方便加载不用每次都解matlab文件。这里写个自定义的Dataset类新手直接抄就能用import os import torch import numpy as np from torch.utils.data import Dataset, DataLoader class CWRUBearingDataset(Dataset): def __init__(self, data_path, label_path, normalizeTrue): # 加载预处理好的振动数据和标签要是你手里是mat文件用scipy.io.loadmat转一下就行 self.data np.load(data_path) self.label np.load(label_path) # 归一化到0-1区间防止训练的时候loss直接炸上天 if normalize: self.data (self.data - self.data.min()) / (self.data.max() - self.data.min()) # 1D-CNN的输入要求是 [batch, 通道数, 序列长度]我们的信号是单通道所以加个1维度 self.data torch.tensor(self.data, dtypetorch.float32).unsqueeze(1) self.label torch.tensor(self.label, dtypetorch.long) def __len__(self): return len(self.label) def __getitem__(self, idx): return self.data[idx], self.label[idx] # 举个例子加载源域和目标域源域用负载0的数据目标域用负载1的数据 source_dataset CWRUBearingDataset(./data/source_0_load_data.npy, ./data/source_0_load_label.npy) target_dataset CWRUBearingDataset(./data/target_1_load_data.npy, ./data/target_1_load_label.npy) # 重点源域和目标域的batch size必须一致不然特征拼接的时候会报错 source_loader DataLoader(source_dataset, batch_size32, shuffleTrue, drop_lastTrue) target_loader DataLoader(target_dataset, batch_size32, shuffleTrue, drop_lastTrue)碎碎念我一开始就是没设drop_last导致最后一个batch的数据量不一样训练直接崩了血的教训。还有归一化真的很重要没做之前我的loss直接跑到了几十万训不动一点。第二步1D-CNN特征提取网络我写的是超级简单的两层卷积没有搞什么残差或者复杂的结构新手完全能看懂每一层在干嘛import torch.nn as nn class Simple1DCNN(nn.Module): def __init__(self, num_classes10): super().__init__() # 第一层卷积抓小的振动波动特征比如轴承的冲击脉冲 self.conv1 nn.Conv1d(in_channels1, out_channels16, kernel_size3, stride1, padding1) self.relu1 nn.ReLU() self.pool1 nn.MaxPool1d(kernel_size2, stride2) # 池化降维把序列长度砍半 # 第二层卷积抓更复杂的组合特征 self.conv2 nn.Conv1d(in_channels16, out_channels32, kernel_size3, stride1, padding1) self.relu2 nn.ReLU() self.pool2 nn.MaxPool1d(kernel_size2, stride2) # 全连接层把特征压缩到128维再输出64维的特征用来做域对齐 self.fc nn.Linear(32 * 256, 128) self.feature_layer nn.Linear(128, 64) # 最后加个分类头用来算源域的分类损失 self.classifier nn.Linear(64, num_classes) def forward(self, x): x self.pool1(self.relu1(self.conv1(x))) x self.pool2(self.relu2(self.conv2(x))) # 把二维特征展平成一维方便全连接层处理 x x.view(-1, 32 * 256) x self.fc(x) features self.feature_layer(x) logits self.classifier(features) return features, logits # 自动选择GPU/CPU没有GPU就直接用CPU跑 device cuda if torch.cuda.is_available() else cpu model Simple1DCNN(num_classes10).to(device)碎碎念这里的32*256是我假设原始信号长度是1024经过两次池化后变成了1024/2/2256要是你的信号长度不一样记得改这个数值不然会报形状错误。第三步JDA联合域对齐损失这个是迁移学习的核心我简化了原版JDA的代码新手不用纠结复杂的矩阵运算知道它是用来把源域和目标域的特征拉到一起就行一区top轴承故障诊断迁移学习代码复现 故障诊断代码 复现 首先使用一维的cnn对源域和目标域进行特征提取域适应阶段将源域和目标域作为cnn的输入得到特征然后进行边缘概率分布对齐和条件概率分布对齐也就是进行JDA联合对齐。此域适应方法特别适合初学者了解迁移学习的基础知识。 数据预处理1维数据 网络模型1D-CNN-MMD-Coral 数据集西储大学CWRU 准确率99% 网络框架pytorch 结果输出损失曲线图、准确率曲线图、混淆矩阵、tsne图 使用对象初学者 注意此代码是一个在 GPU 上跑的代码有宝子的电脑只支持 cpu只需要将代码修改成只在 cpu 上跑的就行不光对齐整体的特征分布边缘分布还对齐同一个故障类别的特征分布条件分布比单纯的MMD效果好太多。import torch import torch.nn.functional as F def jda_domain_alignment_loss(source_features, target_features, source_labels, num_classes10): # 把源域和目标域的特征拼在一起 all_features torch.cat([source_features, target_features], dim0) # 用高斯核计算样本之间的相似度 gamma 1.0 pairwise_distance torch.cdist(all_features, all_features, p2) ** 2 kernel_matrix torch.exp(-gamma * pairwise_distance) # 构建联合分布的权重矩阵核心就是让同类样本靠近不同域样本拉远 source_batch source_features.size(0) target_batch target_features.size(0) weight_matrix torch.zeros(source_batch target_batch, source_batch target_batch, devicedevice) # 初始化源域和目标域的整体权重 weight_matrix[:source_batch, :source_batch] 1 / (source_batch ** 2) weight_matrix[source_batch:, source_batch:] 1 / (target_batch ** 2) weight_matrix[:source_batch, source_batch:] -1 / (source_batch * target_batch) weight_matrix[source_batch:, :source_batch] -1 / (source_batch * target_batch) # 加上类内对齐的权重让同一个故障类别的源域和目标域特征更靠近 for class_idx in range(num_classes): source_class_idx (source_labels class_idx).nonzero(as_tupleTrue)[0] target_class_idx (source_labels class_idx).nonzero(as_tupleTrue)[0] source_batch if len(source_class_idx) 0 or len(target_class_idx) 0: continue weight_matrix[source_class_idx[:, None], source_class_idx] 1 / (len(source_class_idx) ** 2) weight_matrix[target_class_idx[:, None], target_class_idx] 1 / (len(target_class_idx) ** 2) weight_matrix[source_class_idx[:, None], target_class_idx] - 1 / (len(source_class_idx) * len(target_class_idx)) weight_matrix[target_class_idx[:, None], source_class_idx] - 1 / (len(source_class_idx) * len(target_class_idx)) # 计算最终的JDA损失 loss torch.trace(torch.matmul(torch.matmul(kernel_matrix, weight_matrix), kernel_matrix.T)) return loss第四步完整训练流程把上面的东西拼在一起就是完整的训练循环了import torch.optim as optim import matplotlib.pyplot as plt # 初始化损失函数和优化器 ce_loss_fn nn.CrossEntropyLoss() optimizer optim.Adam(model.parameters(), lr1e-4) # 超参数大家可以自己调调参就是玄学 lambda_jda 0.5 epochs 50 loss_list [] acc_list [] for epoch in range(epochs): model.train() total_train_loss 0.0 total_train_acc 0.0 # 同时遍历源域和目标域的dataloader for (source_x, source_y), (target_x, target_y) in zip(source_loader, target_loader): source_x, source_y source_x.to(device), source_y.to(device) target_x, target_y target_x.to(device), target_y.to(device) # 提取特征和分类结果 source_feat, source_logits model(source_x) target_feat, target_logits model(target_x) # 计算分类损失和域对齐损失 cls_loss ce_loss_fn(source_logits, source_y) align_loss jda_domain_alignment_loss(source_feat, target_feat, source_y, num_classes10) total_loss cls_loss lambda_jda * align_loss # 反向传播更新参数 optimizer.zero_grad() total_loss.backward() optimizer.step() # 统计指标 total_train_loss total_loss.item() pred torch.argmax(source_logits, dim1) total_train_acc (pred source_y).sum().item() / len(source_y) # 打印每个epoch的结果 avg_loss total_train_loss / len(source_loader) avg_acc total_train_acc / len(source_loader) loss_list.append(avg_loss) acc_list.append(avg_acc) print(fEpoch [{epoch1}/{epochs}] | 平均损失: {avg_loss:.4f} | 源域准确率: {avg_acc:.4f})碎碎念要是你只有CPU就把所有的.to(device)删掉或者改成.cpu()跑起来会慢一点我用1050Ti跑50个epoch大概10分钟CPU的话大概半小时左右。第五步结果可视化训练完之后一定要画图看效果不然你都不知道自己训了个啥from sklearn.manifold import TSNE from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix # 1. 画损失和准确率曲线 plt.figure(figsize(12, 4)) plt.subplot(1, 2, 1) plt.plot(loss_list, label训练损失) plt.xlabel(Epoch) plt.ylabel(Loss) plt.legend() plt.subplot(1, 2, 2) plt.plot(acc_list, label源域准确率) plt.xlabel(Epoch) plt.ylabel(Accuracy) plt.legend() plt.savefig(./train_curve.png) # 2. 画混淆矩阵用目标域的数据测试 model.eval() all_pred [] all_true [] with torch.no_grad(): for x, y in target_loader: x, y x.to(device), y.to(device) _, logits model(x) pred torch.argmax(logits, dim1) all_pred.extend(pred.cpu().numpy()) all_true.extend(y.cpu().numpy()) cm confusion_matrix(all_true, all_pred) disp ConfusionMatrixDisplay(confusion_matrixcm, display_labels[正常, 内圈故障0.007, 外圈故障0.007, 滚动体故障]) disp.plot(cmapplt.cm.Blues) plt.savefig(./confusion_matrix.png) # 3. t-SNE可视化特征对齐效果 all_source_feat [] all_source_label [] all_target_feat [] all_target_label [] with torch.no_grad(): for x, y in source_loader: x, y x.to(device), y.to(device) feat, _ model(x) all_source_feat.extend(feat.cpu().numpy()) all_source_label.extend(y.cpu().numpy()) for x, y in target_loader: x, y x.to(device), y.to(device) feat, _ model(x) all_target_feat.extend(feat.cpu().numpy()) all_target_label.extend(y.cpu().numpy()) all_feat np.concatenate([all_source_feat, all_target_feat], axis0) # 把目标域的标签加10和源域区分开 all_label all_source_label [l10 for l in all_target_label] tsne TSNE(n_components2, random_state42) feat_2d tsne.fit_transform(all_feat) plt.figure(figsize(8, 8)) plt.scatter(feat_2d[:len(all_source_feat), 0], feat_2d[:len(all_source_feat), 1], cblue, label源域, alpha0.5) plt.scatter(feat_2d[len(all_source_feat):, 0], feat_2d[len(all_source_feat):, 1], cred, label目标域, alpha0.5) plt.legend() plt.savefig(./tsne_visualization.png)碎碎念t-SNE图真的太直观了对齐好的话源域和目标域的点会混在一起要是没对齐就是两堆分开的颜色比看准确率爽多了。我这次跑出来的目标域准确率能到99%左右和博主说的差不多。最后唠点踩坑经验源域和目标域的batch size必须一致不然特征拼接会报错1D-CNN的输入一定要加通道维度不然会报形状错误标签一定要搞对我一开始把内圈和外圈的标签搞反了准确率直接跌到50%要是训练的时候loss不下降试试把学习率调小一点或者把lambda_jda改大一点完整的代码我已经打包上传到GitHub了需要的小伙伴直接搜1d-cnn-jda-cwru就能找到有啥问题也可以留言问我看到都会回的。

相关文章:

最近帮实验室刚入门的师弟复现了西储大学轴承故障的迁移学习代码,本来以为是手到擒来的活,结果还是踩了好几个坑,刚好整理出来给同样摸鱼入门的小伙伴参考

一区top轴承故障诊断迁移学习代码复现 故障诊断代码 复现首先使用一维的cnn对源域和目标域进行特征提取,域适应阶段:将源域和目标域作为cnn的输入得到特征,然后进行边缘概率分布对齐和条件概率分布对齐,也就是进行JDA联合对齐。此…...

塔罗牌选框架:准确率超机器学习模型

技术选型困境与创新突破在软件测试领域,技术栈选择一直是核心挑战。传统方法依赖历史数据和机器学习模型,但常陷入“预测陷阱”——过度依赖过往经验导致创新盲区。例如,自动化测试框架的错误选型每年造成巨额损失:38.7%源于技术生…...

2026 年智慧工地排名榜单第一|山东建安物联科技有限公司

2026 年度智慧工地综合实力榜单正式揭晓,山东建安物联科技有限公司(大建安)凭借标准引领、技术实力与标杆项目,登顶全国榜首,成为行业公认的智慧工地领军企业。公司打造的中建八局烟台崆峒胜境项目,获评国家…...

如何快速上手TradingView图表库:15+框架完整集成实战指南

如何快速上手TradingView图表库:15框架完整集成实战指南 【免费下载链接】charting-library-examples Examples of Charting Library integrations with other libraries, frameworks and data transports 项目地址: https://gitcode.com/gh_mirrors/ch/charting-…...

Excel 技巧:一键批量填充空值

🚀 操作步骤选中区域首先,用鼠标选中包含空值的目标数据区域。定位空值按下快捷键 Ctrl G 打开“定位”对话框:点击左下角的 「定位条件...」。选择 「空值」。点击「确定」。✅ 此时,区域内所有空白单元格已被高亮选中。输入公式…...

NaViL-9B效果展示:电商主图自动提取卖点文案+竞品对比分析

NaViL-9B效果展示:电商主图自动提取卖点文案竞品对比分析 1. 多模态大模型惊艳登场 想象一下,当你上传一张商品图片,AI不仅能准确识别图片内容,还能自动生成吸引人的卖点文案——这就是NaViL-9B带来的革命性体验。作为原生多模态…...

Python 3.13 + CUDA 13.0编译轮子

核心工具链安装 1、安装 Visual Studio 2022 (勾选 “使用 C 的桌面开发”) 2、安装 CUDA Toolkit 13.0环境变量注入 在终端执行,确保编译器能精准定位 CUDA 路径:set CUDA_PATHD:\Program Files\NVIDIA_GPU_Computing_Toolkit\v13 set PATH%CUDA_PATH%\…...

League Akari:英雄联盟玩家的终极智能辅助工具实战指南

League Akari:英雄联盟玩家的终极智能辅助工具实战指南 【免费下载链接】League-Toolkit 兴趣使然的、简单易用的英雄联盟工具集。支持战绩查询、自动秒选等功能。基于 LCU API。 项目地址: https://gitcode.com/gh_mirrors/le/League-Toolkit 你是否厌倦了在…...

AI赋能安装流程:快马智能诊断工具,自动解决软件安装兼容性问题

在开发软件的过程中,安装环节往往是第一个拦路虎。特别是当遇到系统环境复杂、依赖库版本冲突、权限配置等问题时,传统的安装方式常常让人头疼不已。最近我在尝试开发一个智能安装问题诊断工具时,发现InsCode(快马)平台的AI辅助功能特别实用&…...

南京四季旅游攻略:最美时节去最美地方

南京四季旅游攻略:最美时节去最美地方 🌸🍃🍂❄️本文作者:南京码农 发布日期:2026年3月26日 关键词:南京旅游、四季景点、旅游攻略、南京必去、季节推荐前言:南京,一座四…...

ESP32 SPI性能调优指南:从80MHz时钟到DMA配置,避开那些坑

ESP32 SPI性能调优实战:突破80MHz时钟与DMA配置的终极指南 当你在ESP32项目中遇到SPI通信速度瓶颈时,是否曾为如何突破80MHz时钟限制而苦恼?是否在配置DMA时踩过各种坑?本文将带你深入ESP32 SPI性能优化的核心领域,从硬…...

AI+医疗从模型到产品:做一个真正可用系统,需要跨过哪些坎?

# AI医疗从模型到产品:做一个真正可用系统,需要跨过哪些坎?做 AI医疗的人,常常会经历一个很像的阶段。前期我们把大部分精力放在模型上:换 backbone、调 loss、做多模态融合、补校准、压错误样本,最后终于把…...

如何用dpkg-architecture解决Debian软件包的多架构依赖问题?

深度解析dpkg-architecture:Debian多架构依赖管理的实战指南 在Debian软件包开发领域,多架构支持一直是开发者面临的复杂挑战之一。随着ARM架构的崛起和异构计算场景的普及,单一架构的软件包已经无法满足现代计算需求。本文将带您深入探索dpk…...

从零部署JetLinks社区版:一站式物联网平台本地化搭建实战

1. JetLinks社区版:物联网开发的瑞士军刀 第一次接触JetLinks社区版是在三年前的一个智能家居项目上。当时客户要求两周内搭建一个能管理5000设备的物联网平台,还要支持自定义协议开发。在对比了多个开源方案后,JetLinks的模块化设计让我眼前…...

解锁Navicat密码:突破加密限制的开源解密工具

解锁Navicat密码:突破加密限制的开源解密工具 【免费下载链接】navicat_password_decrypt 忘记navicat密码时,此工具可以帮您查看密码 项目地址: https://gitcode.com/gh_mirrors/na/navicat_password_decrypt 当数据库连接密码被Navicat加密保存却无法记起&…...

模电小白必看:3种基本放大电路实战对比(附电路图+避坑指南)

模电入门实战:三大基础放大电路深度解析与避坑指南 刚接触模拟电路时,面对共射极、共集极和共基极这三种基本放大电路,很多初学者都会感到困惑——它们看起来相似,但特性却大不相同。本文将用面包板搭建的真实电路和示波器实测波形…...

深入解析服务器License管理:从基础命令到实战应用

1. 服务器License管理:为什么它比你想的更重要 如果你管理过服务器,尤其是那些运行着像CAD、EDA、仿真分析这类专业软件的服务器,那你肯定对“License”这个词不陌生。它就像软件的“通行证”,没有它,再强大的硬件也只…...

纺织抗菌,选对材料才关键

在纺织行业中,抗菌消臭性能是提升产品附加值的核心抓手,其中贴身衣物、家纺等贴身类产品,因长期接触人体或所处环境特性,细菌滋生、异味残留等问题尤为突出。DN128抗菌消臭剂作为高效无机消臭材料,可广泛用作面料及家纺…...

5分钟玩转OpenClaw:nanobot镜像云端体验与本地调试对比

5分钟玩转OpenClaw:nanobot镜像云端体验与本地调试对比 1. 为什么需要对比云端与本地两种体验方式 作为一个长期折腾AI工具的开发者,我最近在测试OpenClaw时遇到了一个典型困境:是直接在本地电脑安装全套环境,还是先用云端沙盒快…...

JAVA重点基础、进阶知识及易错点总结(10)Map 接口(HashMap、LinkedHashMap、TreeMap)

&#x1f680; Java 巩固进阶 第10天 主题&#xff1a;Map 接口深度解析 —— 键值对的高效艺术&#x1f4c5; 进度概览&#xff1a;掌握 Java 中最灵活的数据结构。 &#x1f4a1; 核心价值&#xff1a; 动态数据承载&#xff1a;SpringBoot 中接收前端动态参数 (Map<Stri…...

vue新手福音:快马ai帮你秒建可运行环境,专注学习第一行代码

作为一个刚接触Vue的新手&#xff0c;最让我头疼的就是环境搭建。记得第一次尝试安装Node.js、配置npm、理解脚手架的时候&#xff0c;光是解决各种报错就花了大半天时间。直到发现了InsCode(快马)平台&#xff0c;才明白原来入门可以这么简单。 环境搭建的痛点 传统方式需要先…...

数据开发平台如何落地实操?数据开发平台核心价值是什么?

数据开发平台是企业数字化建设的核心载体&#xff0c;搭建合规高效的数据开发平台&#xff0c;才能打通数据流转全链路&#xff0c;而多数企业落地数据开发平台时&#xff0c;往往陷入流程混乱、效率低下的困境。开始之前给大家分享一份数字化全流程资料包:https://s.fanruan.c…...

UNIGUI 修改网页图标 Delphi

网页图标delphi 软件上方工具栏Project -> Options -> Application -> Icons修改图标点击第一个LoadIcon按钮&#xff0c;然后选择一个你目标的.ioc格式大小是128*128的图标&#xff0c;点击 Save保存即可。服务器运行图标打开ServerModule页面&#xff0c;点击UniSer…...

2026最新Java金三银四面试参考指南公开!

想必有很多小伙伴这会已经在为金三银四面试跳槽做准备了。临近面试肯定是要想办法提升自己的面试能力&#xff0c;这个时候如果还去一昧地提升自己的代码能力对面试是毫无帮助的。大多数人在面试的时候都会遇到以下几种情况&#xff08;大家可以看看自己中了几个&#xff09;&a…...

nli-distilroberta-base前端集成案例:Vue.js构建智能文本分析界面

nli-distilroberta-base前端集成案例&#xff1a;Vue.js构建智能文本分析界面 1. 场景价值与方案概述 电商平台的客服系统每天需要处理大量用户咨询&#xff0c;其中很多问题都涉及产品参数的对比&#xff08;如"这款手机电池容量比A型号大吗&#xff1f;"&#xf…...

大模型赋能多尺度空间智能:从具身感知到地球系统建模的跨学科探索

1. 大模型如何重构空间智能的认知框架 当AlphaGo击败人类棋手时&#xff0c;我们惊叹于AI的策略能力&#xff1b;但当大语言模型开始理解三维空间关系时&#xff0c;这标志着机器认知的质变。空间智能的本质是理解物体间的相对位置、距离和运动规律&#xff0c;这种能力对人类而…...

Unity游戏开发:A*寻路算法实战,5步搞定NPC智能移动(附完整Demo)

Unity游戏开发&#xff1a;A*寻路算法实战指南与高级优化技巧 在游戏开发中&#xff0c;NPC的智能移动一直是开发者需要解决的核心问题之一。想象一下&#xff0c;当玩家在《魔兽世界》中穿越荆棘谷时&#xff0c;那些巡逻的巨魔守卫是如何绕过树木和山丘找到最短路径的&#x…...

告别鉴权内耗,让每一位Java开发者都能轻松上手

写Java的这些年&#xff0c;无论是初入职场的新手&#xff0c;还是深耕多年的老兵&#xff0c;谁没在「鉴权」上栽过跟头&#xff1f; 熬夜啃Spring Security的复杂配置&#xff0c;对着一堆过滤器链抓耳挠腮&#xff1b;用Shiro做前后端分离项目&#xff0c;为了适配Token模式…...

项目分享|LLM驱动的多市场股票智能分析器

项目分享|LLM驱动的多市场股票智能分析器 引言 在股票投资分析中&#xff0c;实时行情跟踪、多维度数据解析和科学决策判断是核心需求&#xff0c;而个人投资者往往面临数据分散、分析耗时、缺乏专业工具的问题。由ZhuLinsen开源的daily_stock_analysis项目完美解决了这些痛点…...

PT工具效率革命:一站式解决PT站点种子管理难题

PT工具效率革命&#xff1a;一站式解决PT站点种子管理难题 【免费下载链接】PT-Plugin-Plus PT 助手 Plus&#xff0c;为 Microsoft Edge、Google Chrome、Firefox 浏览器插件&#xff08;Web Extensions&#xff09;&#xff0c;主要用于辅助下载 PT 站的种子。 项目地址: h…...