当前位置: 首页 > news >正文

Graph U-Net Code【图分类】

1. main.py


# GNet是需要用到的model
net = GNet(G_data.feat_dim, G_data.num_class, args) # graph, 特征维度,类别数,参数
trainer = Trainer(args, net, G_data) #开始训练数据
# 正式开始训练数据
trainer.train()

2. network.py

class GNet(nn.Module):def __init__(self, in_dim, n_classes, args):super(GNet, self).__init__()self.n_act = getattr(nn, args.act_n)()# getattr() 是 Python 内置的一个函数,可以用来获取一个对象的属性值或方法self.c_act = getattr(nn, args.act_c)()# print('GNet1: in_dim=', in_dim, 'n_class=',n_classes)  # GNet1: in_dim= 82 n_class= 2"用的是GCN的框架,输入分别是feat dim、layer dim、network act、drop net(net表示GCN网络本身的参数)"self.s_gcn = GCN(in_dim, args.l_dim, self.n_act, args.drop_n)self.g_unet = GraphUnet(args.ks, args.l_dim, args.l_dim, args.l_dim, self.n_act, args.drop_n)"""nn.Linear定义一个神经网络的线性层,方法如下:torch.nn.Linear(in_features, # 输入的神经元个数out_features, # 输出神经元个数bias=True # 是否包含偏置)"""self.out_l_1 = nn.Linear(3*args.l_dim*(args.l_num+1), args.h_dim)self.out_l_2 = nn.Linear(args.h_dim, n_classes)"nn.Dropout(p = 0.3) # 表示每个神经元有0.3的可能性不被激活"self.out_drop = nn.Dropout(p=args.drop_c)Initializer.weights_init(self)def forward(self, gs, hs, labels):print('GNet2: gs=',type(gs), len(gs), 'hs=',type(hs), len(hs), 'labels:',type(labels),labels.shape)# GNet2: gs= <class 'list'> 32 hs= <class 'list'> 32 labels: <class 'torch.Tensor'> torch.Size([32])hs = self.embed(gs, hs)print('GNet2: hs=', type(hs), hs.shape)logits = self.classify(hs)return self.metric(logits, labels)

3. trainer.py

class Trainer:"init初始化,输入分别是arg参数、gcn net、graph Data,将这些装进self里面"def __init__(self, args, net, G_data):self.args = argsself.net = netself.feat_dim = G_data.feat_dimself.fold_idx = G_data.fold_idxself.init(args, G_data.train_gs, G_data.test_gs)# 若是有显卡,则用显卡跑if torch.cuda.is_available():self.net.cuda()"初始化——开始训练数据"def init(self, args, train_gs, test_gs):print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))# 分成训练集和测试集,记载数据train_data = GraphData(train_gs, self.feat_dim)test_data = GraphData(test_gs, self.feat_dim)# DataLoader 为pytorch 内部类,此时只需要指定trainset, batch_size, shuffle, num_workers, ...等self.train_d = train_data.loader(self.args.batch, True)self.test_d = test_data.loader(self.args.batch, False)self.optimizer = optim.Adam(self.net.parameters(), lr=self.args.lr, amsgrad=True,weight_decay=0.0008)
    def train(self):max_acc = 0.0train_str = 'Train epoch %d: loss %.5f acc %.5f'test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'line_str = '%d:\t%.5f\n'for e_id in range(self.args.num_epochs):self.net.train()# 从每个epoch开始训练loss, acc = self.run_epoch(e_id, self.train_d, self.net, self.optimizer)print(train_str % (e_id, loss, acc))with torch.no_grad():self.net.eval()loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)max_acc = max(max_acc, acc)print(test_str % (e_id, loss, acc, max_acc))with open(self.args.acc_file, 'a+') as f:f.write(line_str % (self.fold_idx, max_acc))
    def run_epoch(self, epoch, data, model, optimizer):#self.run_epoch(e_id, self.train_d, self.net, self.optimizer)losses, accs, n_samples = [], [], 0for batch in tqdm(data, desc=str(epoch), unit='b'):cur_len, gs, hs, ys = batchgs, hs, ys = map(self.to_cuda, [gs, hs, ys])loss, acc = model(gs, hs, ys)losses.append(loss*cur_len)accs.append(acc*cur_len)n_samples += cur_lenif optimizer is not None:optimizer.zero_grad()loss.backward()optimizer.step()avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samplesreturn avg_loss.item(), avg_acc.item()

不懂

class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_features"""为啥要这么做???5555555555555555555555555555"""self.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return output

相关文章:

Graph U-Net Code【图分类】

1. main.py # GNet是需要用到的model net GNet(G_data.feat_dim, G_data.num_class, args) # graph, 特征维度&#xff0c;类别数&#xff0c;参数 trainer Trainer(args, net, G_data) #开始训练数据 # 正式开始训练数据 trainer.train()2. network.py class GNet(nn.Modul…...

PTA 秀恩爱分得快(树)

题目 古人云&#xff1a;秀恩爱&#xff0c;分得快。 互联网上每天都有大量人发布大量照片&#xff0c;我们通过分析这些照片&#xff0c;可以分析人与人之间的亲密度。如果一张照片上出现了 K 个人&#xff0c;这些人两两间的亲密度就被定义为 1/K。任意两个人如果同时出现在…...

文心一言4.0对比ChatGPT4.0有什么优势?

目录 总结 文心一言4.0的优势 文心一言4.0的劣势 免费分享使用工具 后话 生成式AI的困境 “不会问”“不会用”“不敢信” 为什么要出收费版本&#xff1f; 目前使用过国内的文心一言3.5和WPS AI&#xff0c;国外的ChatGPT4.0。 文心一言和其他国内产品相比&#xff0…...

美观且可以很方便自定义的MATLAB绘图颜色

函数介绍 主函数是draw_test&#xff0c;用于测试函数。 draw_h是函数&#xff0c;用于给Matlab提供美观且可以很方便自定义的绘图颜色。 draw_h函数介绍 这是一个带输入输出的函数&#xff0c;输入1/2/3&#xff0c;输出下面三种颜色库的配色&#xff0c;每种库均有五种颜色…...

基于jsp,ssm物流快递管理系统

开发工具&#xff1a;eclipse&#xff0c;jdk1.8 服务器&#xff1a;tomcat7.0 数据库&#xff1a;mysql5.7 技术&#xff1a; springspringMVCmybaitsEasyUI 项目包括用户前台和管理后台两部分&#xff0c;功能介绍如下&#xff1a; 一、用户(前台)功能&#xff1a; 用…...

陪诊系统|挂号陪护搭建二开陪诊师入驻就医小程序

我们的陪诊小程序拥有丰富多样的功能&#xff0c;旨在最大程度满足现代人的需求。首先&#xff0c;我们采用了智能排队系统&#xff0c;通过扫描二维码获取排号信息&#xff0c;让您从繁琐的排队过程中解放出来。其次&#xff0c;我们提供了多种支付方式&#xff0c;不仅可以实…...

恒驰服务 | 华为云数据使能专家服务offering之大数据建设

恒驰大数据服务主要针对客户在进行智能数据迁移的过程中&#xff0c;存在业务停机、数据丢失、迁移周期紧张、运维成本高等问题&#xff0c;通过为客户提供迁移调研、方案设计、迁移实施、迁移验收等服务内容&#xff0c;支撑客户实现快速稳定上云&#xff0c;有效降低时间成本…...

轻量级狂雨小说cms系统源码 v1.5.2 基于ThinkPHP5.1+MySQL

轻量级狂雨小说cms系统源码 v1.5.2 基于ThinkPHP5.1MySQL的技术开发 狂雨小说cms提供一个轻量级小说网站解决方案&#xff0c;基于ThinkPHP5.1MySQL的技术开发。 KYXSCMS,灵活&#xff0c;方便&#xff0c;人性化设计简单易用是最大的特色&#xff0c;是快速架设小说类网站首选…...

Leetcode刷题详解——Pow(x, n)

1. 题目链接&#xff1a;50. Pow(x, n) 2. 题目描述&#xff1a; 实现 pow(x, n) &#xff0c;即计算 x 的整数 n 次幂函数&#xff08;即&#xff0c;xn &#xff09;。 示例 1&#xff1a; 输入&#xff1a;x 2.00000, n 10 输出&#xff1a;1024.00000示例 2&#xff1a;…...

计算机毕业设计选题推荐-校园失物招领微信小程序/安卓APP-项目实战

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…...

人工智能基础_机器学习011_梯度下降概念_梯度下降步骤_函数与导函数求解最优解---人工智能工作笔记0051

然后我们来看一下梯度下降,这里先看一个叫 无约束最优化问题,,值得是从一个问题的所有可能的备选方案中选最优的方案, 我们的知道,我们的正态分布这里,正规的一个正态分布,还有我们的正规方程,他的这个x,是正规的,比如上面画的这个曲线,他的这个x,就是大于0的对吧,而现实生活…...

开放式耳机能保护听力吗,开放式耳机跟骨传导耳机哪个更好?

如果从严格意义上来讲的话&#xff0c;开放式耳机中的骨传导耳机是能保护听力&#xff0c;现如今的开放式耳机是一个统称&#xff0c;将所有不入耳的类目全部规划到一块。因此在开放式耳机中存在着一些耳机是只能够保持周边环境音&#xff0c;而不是保护听力的。 下面让我来给…...

【Qt之QLocale】使用

描述 QLocale类可以在多种语言之间进行数字和字符串的转换。 QLocale类在构造函数中使用语言/国家对进行初始化&#xff0c;并提供类似于QString中的数字转字符串和字符串转数字的转换函数。 示例&#xff1a; QLocale egyptian(QLocale::Arabic, QLocale::Egypt);QString s1 …...

维修服务预约小程序的效果如何

生活服务中维修项目绝对是需求量很高的&#xff0c;如常见的保洁、管道疏通、数码维修、安装、便民服务等&#xff0c;可以说每天都有生意&#xff0c;而对相关维修店企业来说&#xff0c;如何获得更多生意很重要。 接下来让我们看看通过【雨科】平台制作维修服务预约小程序能…...

前端架构体系调研整理汇总

1.公司研发人数与前端体系 小型创业公司 前端人数&#xff1a; < 3 人 产品类型&#xff1a; 产品不是非常成熟&#xff0c;比较新颖。 项目流程&#xff1a;不完善&#xff0c;快、紧促&#xff0c;没有固定的时间排期。 技术栈&#xff1a; 没有历史包袱&#xff0c;技…...

DrawerLayout的点击事件会穿透到底部,如何拦截?

DrawerLayout实现侧后&#xff0c;发现了一个问题。点击DrawerLayout的画面&#xff0c;会触发覆盖的底层页面的控件。由此说明点击事件穿透到了底部。但是我只需要触发抽屉布局里的控件&#xff0c;不想触发底层被覆盖的看不见的按钮&#xff0c;由此我想到的时让抽屉页面拦截…...

在Spring boot中 使用JWT和过滤器实现登录认证

在Spring boot中 使用JWT和过滤器实现登录认证 一、登录获得JWT 在navicat中运行如下sql,准备一张user表 -- ---------------------------- -- Table structure for t_user -- ---------------------------- DROP TABLE IF EXISTS t_user; CREATE TABLE t_user (id int(11) …...

天堂2如何对版本里面的内容进行修改

天堂2写装备属性的问题 早一点的版本属性都是写在armor文件夹 xml档里&#xff0c;不再写armor里了 armor文件夹里只有防御 HP MP增加量&#xff0c;套装的属性都用一个技能形式写在 skills里了 在配合数据库里一个叫armorsets实现套装属性&#xff0c;拿皇家套做说明。 id 43…...

代码随想录Day33 LeetCode T62不同路径 LeetCode T63 不同路径II

前言 动规五部曲 1.确定dp数组含义 2.确定递推公式 3.初始化数组 4.确定遍历方式 5.打印dp数组查看分析问题 LeetCode T62 不同路径 题目链接:62. 不同路径 - 力扣&#xff08;LeetCode&#xff09; 题目思路: 注:n行m列而不是m行n列 1.确定dp数组含义 代表到达此下标有多少条…...

【计算机网络】分层模型和应用协议

网络分层模型和应用协议 1. 分层模型 1.1 五层网络模型 网络要解决的问题是&#xff1a;两个程序之间如何交换数据。 四层&#xff1f;五层&#xff1f;七层&#xff1f; 2. 应用层协议 2.1 URL URL&#xff08;uniform resource locator&#xff0c;统一资源定位符&#…...

挑战杯推荐项目

“人工智能”创意赛 - 智能艺术创作助手&#xff1a;借助大模型技术&#xff0c;开发能根据用户输入的主题、风格等要求&#xff0c;生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用&#xff0c;帮助艺术家和创意爱好者激发创意、提高创作效率。 ​ - 个性化梦境…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

Day131 | 灵神 | 回溯算法 | 子集型 子集

Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 笔者写过很多次这道题了&#xff0c;不想写题解了&#xff0c;大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

linux arm系统烧录

1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 &#xff08;忘了有没有这步了 估计有&#xff09; 刷机程序 和 镜像 就不提供了。要刷的时…...

Neo4j 集群管理:原理、技术与最佳实践深度解析

Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...

OpenLayers 分屏对比(地图联动)

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能&#xff0c;和卷帘图层不一样的是&#xff0c;分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)

1.获取 authorizationCode&#xff1a; 2.利用 authorizationCode 获取 accessToken&#xff1a;文档中心 3.获取手机&#xff1a;文档中心 4.获取昵称头像&#xff1a;文档中心 首先创建 request 若要获取手机号&#xff0c;scope必填 phone&#xff0c;permissions 必填 …...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

深度学习习题2

1.如果增加神经网络的宽度&#xff0c;精确度会增加到一个特定阈值后&#xff0c;便开始降低。造成这一现象的可能原因是什么&#xff1f; A、即使增加卷积核的数量&#xff0c;只有少部分的核会被用作预测 B、当卷积核数量增加时&#xff0c;神经网络的预测能力会降低 C、当卷…...

从零开始了解数据采集(二十八)——制造业数字孪生

近年来&#xff0c;我国的工业领域正经历一场前所未有的数字化变革&#xff0c;从“双碳目标”到工业互联网平台的推广&#xff0c;国家政策和市场需求共同推动了制造业的升级。在这场变革中&#xff0c;数字孪生技术成为备受关注的关键工具&#xff0c;它不仅让企业“看见”设…...