当前位置: 首页 > news >正文

Graph U-Net Code【图分类】

1. main.py


# GNet是需要用到的model
net = GNet(G_data.feat_dim, G_data.num_class, args) # graph, 特征维度,类别数,参数
trainer = Trainer(args, net, G_data) #开始训练数据
# 正式开始训练数据
trainer.train()

2. network.py

class GNet(nn.Module):def __init__(self, in_dim, n_classes, args):super(GNet, self).__init__()self.n_act = getattr(nn, args.act_n)()# getattr() 是 Python 内置的一个函数,可以用来获取一个对象的属性值或方法self.c_act = getattr(nn, args.act_c)()# print('GNet1: in_dim=', in_dim, 'n_class=',n_classes)  # GNet1: in_dim= 82 n_class= 2"用的是GCN的框架,输入分别是feat dim、layer dim、network act、drop net(net表示GCN网络本身的参数)"self.s_gcn = GCN(in_dim, args.l_dim, self.n_act, args.drop_n)self.g_unet = GraphUnet(args.ks, args.l_dim, args.l_dim, args.l_dim, self.n_act, args.drop_n)"""nn.Linear定义一个神经网络的线性层,方法如下:torch.nn.Linear(in_features, # 输入的神经元个数out_features, # 输出神经元个数bias=True # 是否包含偏置)"""self.out_l_1 = nn.Linear(3*args.l_dim*(args.l_num+1), args.h_dim)self.out_l_2 = nn.Linear(args.h_dim, n_classes)"nn.Dropout(p = 0.3) # 表示每个神经元有0.3的可能性不被激活"self.out_drop = nn.Dropout(p=args.drop_c)Initializer.weights_init(self)def forward(self, gs, hs, labels):print('GNet2: gs=',type(gs), len(gs), 'hs=',type(hs), len(hs), 'labels:',type(labels),labels.shape)# GNet2: gs= <class 'list'> 32 hs= <class 'list'> 32 labels: <class 'torch.Tensor'> torch.Size([32])hs = self.embed(gs, hs)print('GNet2: hs=', type(hs), hs.shape)logits = self.classify(hs)return self.metric(logits, labels)

3. trainer.py

class Trainer:"init初始化,输入分别是arg参数、gcn net、graph Data,将这些装进self里面"def __init__(self, args, net, G_data):self.args = argsself.net = netself.feat_dim = G_data.feat_dimself.fold_idx = G_data.fold_idxself.init(args, G_data.train_gs, G_data.test_gs)# 若是有显卡,则用显卡跑if torch.cuda.is_available():self.net.cuda()"初始化——开始训练数据"def init(self, args, train_gs, test_gs):print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))# 分成训练集和测试集,记载数据train_data = GraphData(train_gs, self.feat_dim)test_data = GraphData(test_gs, self.feat_dim)# DataLoader 为pytorch 内部类,此时只需要指定trainset, batch_size, shuffle, num_workers, ...等self.train_d = train_data.loader(self.args.batch, True)self.test_d = test_data.loader(self.args.batch, False)self.optimizer = optim.Adam(self.net.parameters(), lr=self.args.lr, amsgrad=True,weight_decay=0.0008)
    def train(self):max_acc = 0.0train_str = 'Train epoch %d: loss %.5f acc %.5f'test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'line_str = '%d:\t%.5f\n'for e_id in range(self.args.num_epochs):self.net.train()# 从每个epoch开始训练loss, acc = self.run_epoch(e_id, self.train_d, self.net, self.optimizer)print(train_str % (e_id, loss, acc))with torch.no_grad():self.net.eval()loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)max_acc = max(max_acc, acc)print(test_str % (e_id, loss, acc, max_acc))with open(self.args.acc_file, 'a+') as f:f.write(line_str % (self.fold_idx, max_acc))
    def run_epoch(self, epoch, data, model, optimizer):#self.run_epoch(e_id, self.train_d, self.net, self.optimizer)losses, accs, n_samples = [], [], 0for batch in tqdm(data, desc=str(epoch), unit='b'):cur_len, gs, hs, ys = batchgs, hs, ys = map(self.to_cuda, [gs, hs, ys])loss, acc = model(gs, hs, ys)losses.append(loss*cur_len)accs.append(acc*cur_len)n_samples += cur_lenif optimizer is not None:optimizer.zero_grad()loss.backward()optimizer.step()avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samplesreturn avg_loss.item(), avg_acc.item()

不懂

class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_features"""为啥要这么做???5555555555555555555555555555"""self.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return output

相关文章:

Graph U-Net Code【图分类】

1. main.py # GNet是需要用到的model net GNet(G_data.feat_dim, G_data.num_class, args) # graph, 特征维度&#xff0c;类别数&#xff0c;参数 trainer Trainer(args, net, G_data) #开始训练数据 # 正式开始训练数据 trainer.train()2. network.py class GNet(nn.Modul…...

PTA 秀恩爱分得快(树)

题目 古人云&#xff1a;秀恩爱&#xff0c;分得快。 互联网上每天都有大量人发布大量照片&#xff0c;我们通过分析这些照片&#xff0c;可以分析人与人之间的亲密度。如果一张照片上出现了 K 个人&#xff0c;这些人两两间的亲密度就被定义为 1/K。任意两个人如果同时出现在…...

文心一言4.0对比ChatGPT4.0有什么优势?

目录 总结 文心一言4.0的优势 文心一言4.0的劣势 免费分享使用工具 后话 生成式AI的困境 “不会问”“不会用”“不敢信” 为什么要出收费版本&#xff1f; 目前使用过国内的文心一言3.5和WPS AI&#xff0c;国外的ChatGPT4.0。 文心一言和其他国内产品相比&#xff0…...

美观且可以很方便自定义的MATLAB绘图颜色

函数介绍 主函数是draw_test&#xff0c;用于测试函数。 draw_h是函数&#xff0c;用于给Matlab提供美观且可以很方便自定义的绘图颜色。 draw_h函数介绍 这是一个带输入输出的函数&#xff0c;输入1/2/3&#xff0c;输出下面三种颜色库的配色&#xff0c;每种库均有五种颜色…...

基于jsp,ssm物流快递管理系统

开发工具&#xff1a;eclipse&#xff0c;jdk1.8 服务器&#xff1a;tomcat7.0 数据库&#xff1a;mysql5.7 技术&#xff1a; springspringMVCmybaitsEasyUI 项目包括用户前台和管理后台两部分&#xff0c;功能介绍如下&#xff1a; 一、用户(前台)功能&#xff1a; 用…...

陪诊系统|挂号陪护搭建二开陪诊师入驻就医小程序

我们的陪诊小程序拥有丰富多样的功能&#xff0c;旨在最大程度满足现代人的需求。首先&#xff0c;我们采用了智能排队系统&#xff0c;通过扫描二维码获取排号信息&#xff0c;让您从繁琐的排队过程中解放出来。其次&#xff0c;我们提供了多种支付方式&#xff0c;不仅可以实…...

恒驰服务 | 华为云数据使能专家服务offering之大数据建设

恒驰大数据服务主要针对客户在进行智能数据迁移的过程中&#xff0c;存在业务停机、数据丢失、迁移周期紧张、运维成本高等问题&#xff0c;通过为客户提供迁移调研、方案设计、迁移实施、迁移验收等服务内容&#xff0c;支撑客户实现快速稳定上云&#xff0c;有效降低时间成本…...

轻量级狂雨小说cms系统源码 v1.5.2 基于ThinkPHP5.1+MySQL

轻量级狂雨小说cms系统源码 v1.5.2 基于ThinkPHP5.1MySQL的技术开发 狂雨小说cms提供一个轻量级小说网站解决方案&#xff0c;基于ThinkPHP5.1MySQL的技术开发。 KYXSCMS,灵活&#xff0c;方便&#xff0c;人性化设计简单易用是最大的特色&#xff0c;是快速架设小说类网站首选…...

Leetcode刷题详解——Pow(x, n)

1. 题目链接&#xff1a;50. Pow(x, n) 2. 题目描述&#xff1a; 实现 pow(x, n) &#xff0c;即计算 x 的整数 n 次幂函数&#xff08;即&#xff0c;xn &#xff09;。 示例 1&#xff1a; 输入&#xff1a;x 2.00000, n 10 输出&#xff1a;1024.00000示例 2&#xff1a;…...

计算机毕业设计选题推荐-校园失物招领微信小程序/安卓APP-项目实战

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…...

人工智能基础_机器学习011_梯度下降概念_梯度下降步骤_函数与导函数求解最优解---人工智能工作笔记0051

然后我们来看一下梯度下降,这里先看一个叫 无约束最优化问题,,值得是从一个问题的所有可能的备选方案中选最优的方案, 我们的知道,我们的正态分布这里,正规的一个正态分布,还有我们的正规方程,他的这个x,是正规的,比如上面画的这个曲线,他的这个x,就是大于0的对吧,而现实生活…...

开放式耳机能保护听力吗,开放式耳机跟骨传导耳机哪个更好?

如果从严格意义上来讲的话&#xff0c;开放式耳机中的骨传导耳机是能保护听力&#xff0c;现如今的开放式耳机是一个统称&#xff0c;将所有不入耳的类目全部规划到一块。因此在开放式耳机中存在着一些耳机是只能够保持周边环境音&#xff0c;而不是保护听力的。 下面让我来给…...

【Qt之QLocale】使用

描述 QLocale类可以在多种语言之间进行数字和字符串的转换。 QLocale类在构造函数中使用语言/国家对进行初始化&#xff0c;并提供类似于QString中的数字转字符串和字符串转数字的转换函数。 示例&#xff1a; QLocale egyptian(QLocale::Arabic, QLocale::Egypt);QString s1 …...

维修服务预约小程序的效果如何

生活服务中维修项目绝对是需求量很高的&#xff0c;如常见的保洁、管道疏通、数码维修、安装、便民服务等&#xff0c;可以说每天都有生意&#xff0c;而对相关维修店企业来说&#xff0c;如何获得更多生意很重要。 接下来让我们看看通过【雨科】平台制作维修服务预约小程序能…...

前端架构体系调研整理汇总

1.公司研发人数与前端体系 小型创业公司 前端人数&#xff1a; < 3 人 产品类型&#xff1a; 产品不是非常成熟&#xff0c;比较新颖。 项目流程&#xff1a;不完善&#xff0c;快、紧促&#xff0c;没有固定的时间排期。 技术栈&#xff1a; 没有历史包袱&#xff0c;技…...

DrawerLayout的点击事件会穿透到底部,如何拦截?

DrawerLayout实现侧后&#xff0c;发现了一个问题。点击DrawerLayout的画面&#xff0c;会触发覆盖的底层页面的控件。由此说明点击事件穿透到了底部。但是我只需要触发抽屉布局里的控件&#xff0c;不想触发底层被覆盖的看不见的按钮&#xff0c;由此我想到的时让抽屉页面拦截…...

在Spring boot中 使用JWT和过滤器实现登录认证

在Spring boot中 使用JWT和过滤器实现登录认证 一、登录获得JWT 在navicat中运行如下sql,准备一张user表 -- ---------------------------- -- Table structure for t_user -- ---------------------------- DROP TABLE IF EXISTS t_user; CREATE TABLE t_user (id int(11) …...

天堂2如何对版本里面的内容进行修改

天堂2写装备属性的问题 早一点的版本属性都是写在armor文件夹 xml档里&#xff0c;不再写armor里了 armor文件夹里只有防御 HP MP增加量&#xff0c;套装的属性都用一个技能形式写在 skills里了 在配合数据库里一个叫armorsets实现套装属性&#xff0c;拿皇家套做说明。 id 43…...

代码随想录Day33 LeetCode T62不同路径 LeetCode T63 不同路径II

前言 动规五部曲 1.确定dp数组含义 2.确定递推公式 3.初始化数组 4.确定遍历方式 5.打印dp数组查看分析问题 LeetCode T62 不同路径 题目链接:62. 不同路径 - 力扣&#xff08;LeetCode&#xff09; 题目思路: 注:n行m列而不是m行n列 1.确定dp数组含义 代表到达此下标有多少条…...

【计算机网络】分层模型和应用协议

网络分层模型和应用协议 1. 分层模型 1.1 五层网络模型 网络要解决的问题是&#xff1a;两个程序之间如何交换数据。 四层&#xff1f;五层&#xff1f;七层&#xff1f; 2. 应用层协议 2.1 URL URL&#xff08;uniform resource locator&#xff0c;统一资源定位符&#…...

Docker 离线安装指南

参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性&#xff0c;不同版本的Docker对内核版本有不同要求。例如&#xff0c;Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本&#xff0c;Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

【机器视觉】单目测距——运动结构恢复

ps&#xff1a;图是随便找的&#xff0c;为了凑个封面 前言 在前面对光流法进行进一步改进&#xff0c;希望将2D光流推广至3D场景流时&#xff0c;发现2D转3D过程中存在尺度歧义问题&#xff0c;需要补全摄像头拍摄图像中缺失的深度信息&#xff0c;否则解空间不收敛&#xf…...

定时器任务——若依源码分析

分析util包下面的工具类schedule utils&#xff1a; ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类&#xff0c;封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz&#xff0c;先构建任务的 JobD…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说&#xff0c;传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度&#xff0c;通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;社区养老保险系统小程序被用户普遍使用&#xff0c;为方…...

用鸿蒙HarmonyOS5实现中国象棋小游戏的过程

下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...

《Docker》架构

文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器&#xff0c;docker&#xff0c;镜像&#xff0c;k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...

【2D与3D SLAM中的扫描匹配算法全面解析】

引言 扫描匹配(Scan Matching)是同步定位与地图构建(SLAM)系统中的核心组件&#xff0c;它通过对齐连续的传感器观测数据来估计机器人的运动。本文将深入探讨2D和3D SLAM中的各种扫描匹配算法&#xff0c;包括数学原理、实现细节以及实际应用中的性能对比&#xff0c;特别关注…...

code-server安装使用,并配置frp反射域名访问

为什么使用 code-server是VSCode网页版开发软件&#xff0c;可以在浏览器访问编程&#xff0c;可以使用vscode中的插件。如果有自己的服务器&#xff0c;使用frp透传后&#xff0c;域名访问在线编程&#xff0c;使用方便&#xff0c;打开的服务端口不需要单独配置&#xff0c;可…...