当前位置: 首页 > news >正文

【Graph Net学习】LINE实现Graph Embedding

一、简介

        LINE (Large-scale Information Network Embedding,2015) 是一种设计用于处理大规模信息网络的算法。它主要的目标是在给定的大规模信息网络中学习高质量的节点嵌入,并尽量保留网络中信息的丰富性。其具体的表现为在一个低 维空间里以向量形式表示网络中的节点,以便后续的机器学习任务可以更好地理解。

LINE算法根据两种相互关联的线性化策略去处理信息图,分别考虑了图节点的一阶邻居和二阶邻居。通过这种方式,LINE既能反映出网络的全局属性又能反映出网络的局部属性。

        调用算法流程如下:

  1. 首先,为图中的每个节点初始化一个随机向量。

  2. 接着,使用一阶邻居的优化原型函数进行训练。在一阶近邻策略中,若两个节点存在直接连接,则他们的向量应该尽可能相近。

  3. 然后,使用二阶邻居的优化原型函数进行训练。在二阶近邻策略中,考虑两节点间的间接联系。例如,若两节点存在共享的邻居,即使他们之间没有直接的联系,他们的向量也应该相近。

  4. 对每个节点,计算其在一阶和二阶优化下的损失函数值,并对其进行优化。

  5. 优化完成后,此时每个节点上的向量就是最终的嵌入表示。

  6. 基于得到的嵌入表示进行后续的分析或机器学习任务。

        接下来就是快乐的代码时间嘿嘿嘿

二、代码

import os
import pandas as pd
import numpy as np
import networkx as nx
import time
import scipy.sparse as sp
from torch_geometric.data import Data
from torch_geometric.transforms import ToSparseTensor
import torch_geometric.utils
from sklearn.preprocessing import LabelEncoderimport torch
import torch.nn as nn#配置项
class configs():def __init__(self):# Dataself.data_path = r'./data'self.save_model_dir = r'./'self.num_nodes = 2708self.embedding_dim = 128self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.learning_rate = 0.01self.epoch = 30self.criterion = nn.BCEWithLogitsLoss()self.istrain = Trueself.istest = Truecfg = configs()def load_cora_data(data_path = './data/cora'):content_df = pd.read_csv(os.path.join(data_path,"cora.content"), delimiter="\t", header=None)content_df.set_index(0, inplace=True)index = content_df.index.tolist()features = sp.csr_matrix(content_df.values[:,:-1], dtype=np.float32)# 处理标签labels = content_df.values[:,-1]class_encoder = LabelEncoder()labels = class_encoder.fit_transform(labels)# 读取引用关系cites_df = pd.read_csv(os.path.join(data_path,"cora.cites"), delimiter="\t", header=None)cites_df[0] = cites_df[0].astype(str)cites_df[1] = cites_df[1].astype(str)cites = [tuple(x) for x in cites_df.values]edges = [(index.index(int(cite[0])), index.index(int(cite[1]))) for cite in cites]edges = np.array(edges).T# 构造Data对象data = Data(x=torch.from_numpy(np.array(features.todense())),edge_index=torch.LongTensor(edges),y=torch.from_numpy(labels))idx_train = range(140)idx_val = range(200, 500)idx_test = range(500, 1500)# 读取Cora数据集 return geometric Data格式def index_to_mask(index, size):mask = np.zeros(size, dtype=bool)mask[index] = Truereturn maskdata.train_mask = index_to_mask(idx_train, size=labels.shape[0])data.val_mask = index_to_mask(idx_val, size=labels.shape[0])data.test_mask = index_to_mask(idx_test, size=labels.shape[0])def to_networkx(data):edge_index = data.edge_index.to(torch.device('cpu')).numpy()G = nx.DiGraph()for src, tar in edge_index.T:G.add_edge(src, tar)return Gnetworkx_data = to_networkx(data)return data,networkx_data
#获取数据:pyg_data:torch_geometric格式;networkx_data:networkx格式def generate_pairs(adj_matrix):# 根据邻接矩阵生成正例和负例pos_pairs = torch.nonzero(adj_matrix, as_tuple=True)pos_u = pos_pairs[0]pos_v = pos_pairs[1]mask = torch.ones_like(adj_matrix)for i in range(len(pos_u)):mask[pos_u[i]][pos_v[i]] = 0mask[pos_v[i]][pos_u[i]] = 0tmp = torch.nonzero(mask, as_tuple=True)#TODO 随机选取负例idx = torch.randperm(tmp[0].size(0))neg_u = tmp[0][idx][:pos_u.size(0)]neg_v = tmp[1][idx][:pos_v.size(0)]return pos_u, pos_v, neg_u, neg_v# 构建LINE网络
class LINE(nn.Module):def __init__(self, num_nodes, embed_dim):super(LINE, self).__init__()#num_nodes为Node个数 , embed_dim为描述Node的Embedding维度self.embed_dim = embed_dimself.num_nodes = num_nodesself.embeddings = nn.Embedding(self.num_nodes, self.embed_dim)self.reset_parameters()def reset_parameters(self):self.embeddings.weight.data.normal_(std=1 / self.embed_dim)def forward(self, pos_u, pos_v, neg_v):emb_pos_u = self.embeddings(pos_u)emb_pos_v = self.embeddings(pos_v)emb_neg_v = self.embeddings(neg_v)pos_scores = torch.sum(torch.mul(emb_pos_u, emb_pos_v), dim=1)neg_scores = torch.sum(torch.mul(emb_pos_u, emb_neg_v), dim=1)return pos_scores, neg_scoresclass LINE_run():def train(self):t = time.time()# 创建一个模型_, networkx_data = load_cora_data()adj_matrix = torch.tensor(nx.adjacency_matrix(networkx_data).toarray(), dtype=torch.float32)model = LINE(num_nodes=cfg.num_nodes, embed_dim=cfg.embedding_dim).to(cfg.device)optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)#Trainmodel.train()for epoch in range(cfg.epoch):optimizer.zero_grad()pos_u, pos_v, neg_u, neg_v = generate_pairs(adj_matrix)pos_u = pos_u.to(cfg.device)pos_v = pos_v.to(cfg.device)neg_v = neg_v.to(cfg.device)pos_scores, neg_scores = model(pos_u, pos_v, neg_v)pos_losses = cfg.criterion(pos_scores, torch.ones(len(pos_scores)).to(cfg.device))neg_losses = cfg.criterion(neg_scores, torch.zeros(len(neg_scores)).to(cfg.device))loss = pos_losses + neg_lossesloss.backward()optimizer.step()print('Epoch: {:04d}'.format(epoch + 1),'loss_train: {:.4f}'.format(loss.item()),'time: {:.4f}s'.format(time.time() - t))torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth'))  # 模型保存print('Embedding dim : ({},{})'.format(model.embeddings.weight.shape[0],model.embeddings.weight.shape[1]))def infer(self):# Create Test Processing_, networkx_data = load_cora_data()adj_matrix = torch.tensor(nx.adjacency_matrix(networkx_data).toarray(), dtype=torch.float32)model_path = os.path.join(cfg.save_model_dir, 'latest.pth')model = torch.load(model_path, map_location=torch.device(cfg.device))model.eval()_, networkx_data = load_cora_data()pos_u, pos_v, neg_u, neg_v = generate_pairs(adj_matrix)pos_u = pos_u.to(cfg.device)pos_v = pos_v.to(cfg.device)neg_v = neg_v.to(cfg.device)pos_scores, neg_scores = model(pos_u, pos_v, neg_v)pos_losses = cfg.criterion(pos_scores, torch.ones(len(pos_scores)).to(cfg.device))neg_losses = cfg.criterion(neg_scores, torch.zeros(len(neg_scores)).to(cfg.device))loss = pos_losses + neg_lossesprint("Test set results:","loss= {:.4f}".format(loss.item()),'Embedding dim : ({},{})'.format(model.embeddings.weight.shape[0], model.embeddings.weight.shape[1]))if __name__ == '__main__':mygraph = LINE_run()if cfg.istrain == True:mygraph.train()if cfg.istest == True:mygraph.infer()

三、输出结果

        跑的是Cora数据,共2708个Node,设置的Embedding维度是128维。上面代码运行完就是长下面这个样子。

Epoch: 0001 loss_train: 1.3863 time: 3.0867s
Epoch: 0002 loss_train: 1.3832 time: 3.7739s
Epoch: 0003 loss_train: 1.3768 time: 4.4471s
...
Epoch: 0028 loss_train: 0.7739 time: 21.3568s
Epoch: 0029 loss_train: 0.7694 time: 22.0310s
Epoch: 0030 loss_train: 0.7663 time: 22.7042s
Embedding dim : (2708,128)
Test set results: loss= 0.7609 Embedding dim : (2708,128)

        效果未知,没有用下游聚类测一下,反正看起来BCE loss是降了哈哈,这期就到这里。
        

相关文章:

【Graph Net学习】LINE实现Graph Embedding

一、简介 LINE (Large-scale Information Network Embedding,2015) 是一种设计用于处理大规模信息网络的算法。它主要的目标是在给定的大规模信息网络中学习高质量的节点嵌入,并尽量保留网络中信息的丰富性。其具体的表现为在一个低 维空间里以向量形式表示网络中的…...

docker安装使用xdebug

docker安装使用xdebug 1、需要先安装PHP xdebug扩展 1.1 到https://pecl.php.net/package/xdebug下载tgz文件,下载当前最新稳定版本的文件。然后把这个tgz文件放到php/extensions目录下,记得install.sh中要替换解压的文件名: installExtensio…...

(1) ESP32获取图像,并通过电脑端服务器显示图像

目录 一、所需器件工具 二、客户端与服务器进行UDP通信 1、客户端代码 2、服务器端代码 3、效果展示 三、客户端拍照,通过UDP传输到服务器进行显示 1、客户端获取图像并UDP传输 2、电脑端服务器显示图像 3、效果展示 四、代码链接 一、所需器件工具 1.ESP3…...

乐鑫科技全球首批支持蓝牙 Mesh Protocol 1.1 协议

乐鑫科技 (688018.SH) 非常高兴地宣布,其自研的蓝牙 Mesh 协议栈 ESP-BLE-MESH 现已支持最新蓝牙 Mesh Protocol 1.1 协议的全部功能,成为全球首批在蓝牙技术联盟 (Bluetooth SIG) 正式发布该协议之前支持该更新的公司之一。这意味着乐鑫在低功耗蓝牙无线…...

1.算法——数据结构学习

算法是解决特定问题求解步骤的描述。 从1加到100的结果 # include <stdio.h> int main(){ int i, sum 0, n 100; // 执行1次for(i 1; i < n; i){ // 执行n 1次sum sum i; // 执行n次} printf("%d", sum); // 执行1次return 0; }高斯求和…...

信息论基础第二章阅读笔记

信息很难用一个简单的定义准确把握。 对于任何一个概率分布&#xff0c;可以定义一个熵&#xff08;entropy&#xff09;的量&#xff0c;它具有许多特性符合度量信息的直观要求。这个概念可以推广到互信息&#xff08;mutual information&#xff09;&#xff0c;互信息是一种…...

Content-Type的取值

接口发送参数、接收响应数据&#xff0c;都需要双方约定好使用什么格式的数据&#xff0c;例如 json、xml。只有双方按照约定好的格式去解析数据才能正确的收发数据。而 Content-Type 就是用来告诉你数据的格式&#xff0c;这样我们才能知道怎么解析参数。 常见的 Content-Typ…...

【趣味JavaScript】5年前端开发都没有搞懂toString和valueOf这两个方法!

&#x1f680; 个人主页 极客小俊 ✍&#x1f3fb; 作者简介&#xff1a;web开发者、设计师、技术分享博主 &#x1f40b; 希望大家多多支持一下, 我们一起进步&#xff01;&#x1f604; &#x1f3c5; 如果文章对你有帮助的话&#xff0c;欢迎评论 &#x1f4ac;点赞&#x1…...

Python中的接口是什么?

在Python中&#xff0c;接口是一种约定或协议&#xff0c;用于定义类应该实现哪些方法或属性。接口并不会提供实际的实现&#xff0c;而是只定义了类应该具有哪些方法和属性的签名。 Python中的接口通常通过抽象基类&#xff08;Abstract Base Class&#xff0c;简称ABC&#…...

自学WEB后端01-安装Express+Node.js框架完成Hello World!

一、前言&#xff0c;网站开发扫盲知识 1.网站搭建开发包括什么&#xff1f; 前端 前端开发主要涉及用户界面&#xff08;UI&#xff09;和用户体验&#xff08;UX&#xff09;&#xff0c;负责实现网站的外观和交互逻辑。前端开发使用HTML、CSS和JavaScript等技术来构建网页…...

从C语言到C++:C++入门知识(1)

朋友们、伙计们&#xff0c;我们又见面了&#xff0c;本期来给大家解读一下有关C语言的相关知识点&#xff0c;如果看完之后对你有一定的启发&#xff0c;那么请留下你的三连&#xff0c;祝大家心想事成&#xff01; C 语 言 专 栏&#xff1a;C语言&#xff1a;从入门到精通 数…...

服务器(Windows系统)自建filebrowser网盘服务器超详细教程

需要依赖&#xff08;工具&#xff09; 轻量服务器&#xff08;云服务器&#xff09;一台 —— 环境Windows Server 2019filebrowser安装包&#xff08;https://github.com/filebrowser/filebrowser/releases&#xff09; 下载安装filebrowser 进入链接下载&#xff1a;https:/…...

扩展欧几里得

扩展欧几里得算法 求 a x b y d axbyd axbyd 的一组解&#xff0c; d gcd ⁡ ( a , b ) d \gcd(a,b) dgcd(a,b)。 辗转相除递归求解。 假设已经求出 b x ( b m o d a ) y d bx (b \bmod a)y d bx(bmoda)yd 的一组解。 a x b y b x ′ ( b m o d a ) y ′ b x …...

MySQL 事务介绍 (事务篇 一)

什么是事务&#xff1f; 事务是一组操作的集合&#xff0c;它是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一个整体一起向系统提交或撤销操作请求&#xff0c;即这些操作要么同时成功&#xff0c;要么同时失败。 注意点&#xff1a;默认MySQL的事务是自动提交…...

nvm nodejs的版本管理工具

nvm 全英文名叫 node.js version management&#xff0c;是一个 nodejs 的版本管理工具&#xff0c;为了解决 nodejs 各种版本存在不兼容现象可以通过他安装和切换不同版本的 nodejs。 一、完全删除之前的 node 和 npm 1. 打开 cmd 命令窗口&#xff0c;输入 npm cache clean…...

terraform简单的开始-vpc cvm创建

从网络开始 从创建VPC开始 复用前面的main.tf的代码&#xff1a; terraform {required_providers {tencentcloud {source "tencentcloudstack/tencentcloud"version "1.81.25"}} } variable "region" {description "腾讯云地域"…...

【MySQL】开启 canal同步MySQL增量数据到ES

开启 canal同步MySQL增量数据到ES canal 是阿里知名的开源项目&#xff0c;主要用途是基于 MySQL 数据库增量日志解析&#xff0c;提供增量数据订阅和消费。示使用 canal 将 MySQL 增量数据同步到ES。 一、集群模式 图中 server 对应一个 canal 运行实例 &#xff0c;对应一…...

密码学概论

1.密码学的三大历史阶段&#xff1a; 第一阶段 古典密码学 依赖设备&#xff0c;主要特点 数据安全基于算法的保密&#xff0c;算法不公开&#xff0c;只要破译算法 密文就会被破解&#xff0c; 在1883年第一次提出 加密算法应该基于算法公开 不影响密文和秘钥的安全&#xff…...

渗透测试中的前端调试(一)

前言 前端调试是安全测试的重要组成部分。它能够帮助我们掌握网页的运行原理&#xff0c;包括js脚本的逻辑、加解密的方法、网络请求的参数等。利用这些信息&#xff0c;我们就可以更准确地发现网站的漏洞&#xff0c;制定出有效的攻击策略。前端知识对于安全来说&#xff0c;…...

SPA项目之登录注册--请求问题(POSTGET)以及跨域问题

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于VueElementUI的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一.ElementUI是什么 &#x1f4a1;…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

使用 SymPy 进行向量和矩阵的高级操作

在科学计算和工程领域&#xff0c;向量和矩阵操作是解决问题的核心技能之一。Python 的 SymPy 库提供了强大的符号计算功能&#xff0c;能够高效地处理向量和矩阵的各种操作。本文将深入探讨如何使用 SymPy 进行向量和矩阵的创建、合并以及维度拓展等操作&#xff0c;并通过具体…...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式&#xff1a;dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一&#xff0c;腐蚀跟膨胀属于反向操作&#xff0c;膨胀是把图像图像变大&#xff0c;而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下&#xff0c;企业和个人创作者为了扩大影响力、提升传播效果&#xff0c;纷纷采用短视频矩阵运营策略&#xff0c;同时管理多个平台、多个账号的内容发布。然而&#xff0c;频繁的文案创作需求让运营者疲于应对&#xff0c;如何高效产出高质量文案成…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

TSN交换机正在重构工业网络,PROFINET和EtherCAT会被取代吗?

在工业自动化持续演进的今天&#xff0c;通信网络的角色正变得愈发关键。 2025年6月6日&#xff0c;为期三天的华南国际工业博览会在深圳国际会展中心&#xff08;宝安&#xff09;圆满落幕。作为国内工业通信领域的技术型企业&#xff0c;光路科技&#xff08;Fiberroad&…...

HTML前端开发:JavaScript 获取元素方法详解

作为前端开发者&#xff0c;高效获取 DOM 元素是必备技能。以下是 JS 中核心的获取元素方法&#xff0c;分为两大系列&#xff1a; 一、getElementBy... 系列 传统方法&#xff0c;直接通过 DOM 接口访问&#xff0c;返回动态集合&#xff08;元素变化会实时更新&#xff09;。…...

绕过 Xcode?使用 Appuploader和主流工具实现 iOS 上架自动化

iOS 应用的发布流程一直是开发链路中最“苹果味”的环节&#xff1a;强依赖 Xcode、必须使用 macOS、各种证书和描述文件配置……对很多跨平台开发者来说&#xff0c;这一套流程并不友好。 特别是当你的项目主要在 Windows 或 Linux 下开发&#xff08;例如 Flutter、React Na…...

UE5 音效系统

一.音效管理 音乐一般都是WAV,创建一个背景音乐类SoudClass,一个音效类SoundClass。所有的音乐都分为这两个类。再创建一个总音乐类&#xff0c;将上述两个作为它的子类。 接着我们创建一个音乐混合类SoundMix&#xff0c;将上述三个类翻入其中&#xff0c;通过它管理每个音乐…...

免费批量Markdown转Word工具

免费批量Markdown转Word工具 一款简单易用的批量Markdown文档转换工具&#xff0c;支持将多个Markdown文件一键转换为Word文档。完全免费&#xff0c;无需安装&#xff0c;解压即用&#xff01; 官方网站 访问官方展示页面了解更多信息&#xff1a;http://mutou888.com/pro…...