当前位置: 首页 > news >正文

batch_softmax_loss

每个用户抽取一定数量的困难负样本,然后ssm

    def batch_softmax_loss_neg(self, user_idx, rec_user_emb, pos_idx, item_emb):user_emb = rec_user_emb[user_idx]product_scores = torch.matmul(F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1).transpose(0, 1))pos_score = (rec_user_emb[user_idx] * item_emb[pos_idx]).sum(dim=-1)pos_score = torch.exp(pos_score / self.temp2)train_mask = self.data.ui_adj[user_idx, self.data.user_num:].toarray()train_mask = torch.tensor(train_mask).cuda()product_scores = product_scores * (1 - train_mask)neg_score, indices = product_scores.topk(500, dim=1, largest=True, sorted=True)neg_score = torch.exp(neg_score[:,400:] / self.temp2).sum(dim=-1)loss = -torch.log(pos_score / (pos_score + neg_score + 10e-6))return torch.mean(loss)
def batch_softmax_loss_neg(user_emb, pos_item_emb, neg_item_emb, temperature):user_emb, pos_item_emb, neg_item_emb = F.normalize(user_emb, dim=1), F.normalize(pos_item_emb, dim=1), F.normalize(neg_item_emb, dim=1)pos_score = (user_emb * pos_item_emb).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)user_emb = user_emb.unsqueeze(1).expand(user_emb.shape[0],neg_item_emb.shape[1],user_emb.shape[1])neg_score = (user_emb * neg_item_emb).sum(dim=-1) # user_emb(n*1*d) neg_item_emb = (n*m*d)neg_score = torch.exp(neg_score / temperature).sum(dim=-1)loss = -torch.log(pos_score / (pos_score + neg_score + 10e-6))return torch.mean(loss)

均匀性损失(错误案例)

# def cal_uniform_loss(user_emb, item_emb):
#     user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
#     distance = user_emb - item_emb  # n*d
#     gaussian_potential = torch.exp(-2 * torch.norm(distance,p=2,dim=1))
#     E_gaussian_potential = gaussian_potential.mean()
#     return torch.log(E_gaussian_potential)

DNS

def DNSbpr(user_emb, pos_item_emb, neg_item_emb):pos_score = torch.mul(user_emb, pos_item_emb).sum(dim=1)user_emb = user_emb.unsqueeze(1).expand(user_emb.shape[0], neg_item_emb.shape[1], user_emb.shape[1])ttl_socre = (user_emb * neg_item_emb).sum(dim=-1)neg_score = torch.max(ttl_socre, dim=1).valuesloss = -torch.log(10e-6 + torch.sigmoid(pos_score - neg_score))return torch.mean(loss)

带margin的infonce

def InfoNCE_margin(view1, view2, temperature, margin, b_cos = True):if b_cos:view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)pos_score = (view1 * view2).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)margin = margin * torch.eye(view1.shape[0])ttl_score = torch.matmul(view1, view2.transpose(0, 1))ttl_score += margin.cuda(0)ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)cl_loss = -torch.log(pos_score / ttl_score+10e-6)return torch.mean(cl_loss)def InfoNCE_tau(view1, view2, temperature):view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)pos_score = (view1 * view2).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)ttl_score = torch.matmul(view1, view2.transpose(0, 1))ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)cl_loss = -torch.log(pos_score / ttl_score+10e-6)return torch.mean(cl_loss)def batch_bpr_loss(user_emb, item_emb):pos_score = torch.mul(user_emb, item_emb).sum(dim=1)neg_score = torch.matmul(user_emb, item_emb.transpose(0, 1)).mean(dim=1)loss = -torch.log(10e-6 + torch.sigmoid(pos_score - neg_score))return torch.mean(loss)def Dis_softmax(view1, view2, temperature, b_cos = True):if b_cos:view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)N,M = view1.shapepos_score = (view1 - view2).norm(p=2, dim=1)pos_score = torch.exp(pos_score / temperature)view1 = view1.unsqueeze(1).expand(N,N,M)view2 = view2.unsqueeze(0).expand(N,N,M)ttl_score = (view1 - view2).norm(p=2, dim=-1)ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)cl_loss = torch.log(pos_score / ttl_score+10e-6)return torch.mean(cl_loss)

LightGCN+对比学习

    def forward(self, perturbed=False):ego_embeddings = torch.cat([self.embedding_dict['user_emb'], self.embedding_dict['item_emb']], 0)all_embeddings = []all_embeddings_cl = ego_embeddingsfor k in range(self.n_layers):ego_embeddings = torch.sparse.mm(self.sparse_norm_adj, ego_embeddings)if perturbed:random_noise = torch.rand_like(ego_embeddings).cuda()ego_embeddings += torch.sign(ego_embeddings) * F.normalize(random_noise, dim=-1) * self.epsall_embeddings.append(ego_embeddings)if k==self.layer_cl-1:all_embeddings_cl +=  F.normalize(all_embeddings[1]-all_embeddings[0], dim=-1) * self.epsfinal_embeddings = torch.stack(all_embeddings, dim=1)final_embeddings = torch.mean(final_embeddings, dim=1)user_all_embeddings, item_all_embeddings = torch.split(final_embeddings, [self.data.user_num, self.data.item_num])user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embeddings_cl, [self.data.user_num, self.data.item_num])if perturbed:return user_all_embeddings, item_all_embeddings,user_all_embeddings_cl, item_all_embeddings_clreturn user_all_embeddings, item_all_embeddings
    def train(self):model = self.model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=self.lRate)hot_uidx, hot_iidx = self.select_ui_idx(500, mode='hot')cold_uidx, cold_iidx = self.select_ui_idx(500, mode='cold')norm_uidx, norm_iidx = self.select_ui_idx(500, mode='norm')iters = 10alphas_init = torch.tensor([1, 2], dtype=torch.float64).to(device)betas_init = torch.tensor([2, 1], dtype=torch.float64).to(device)weights_init = torch.tensor([1 - 0.05, 0.05], dtype=torch.float64).to(device)for epoch in range(self.maxEpoch):# epoch_rec_loss = []bmm_model = BetaMixture1D(iters, alphas_init, betas_init, weights_init)rec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)self.bmm_fit(rec_user_emb, rec_item_emb,torch.arange(self.data.user_num),np.random.randint(0,self.data.item_num, 100),bmm_model)for n, batch in enumerate(next_batch_pairwise(self.data, self.batch_size)):user_idx, pos_idx, rec_neg_idx = batchrec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)user_emb, pos_item_emb= rec_user_emb[user_idx], rec_item_emb[pos_idx]# rec_loss = self.batch_softmax_loss_neg(user_idx, rec_user_emb, pos_idx, rec_item_emb)# rec_neg_idx = torch.tensor(rec_neg_idx,dtype=torch.int64)# rec_neg_item_emb = rec_item_emb[rec_neg_idx]weight = self.getWeightSim(user_emb, pos_item_emb, bmm_model)rec_loss = weighted_SSM(user_emb,pos_item_emb,self.temp2,weight)cl_loss =  self.cl_rate * self.cal_cl_loss([user_idx,pos_idx],rec_user_emb,cl_user_emb,rec_item_emb,cl_item_emb)batch_loss =  rec_loss + l2_reg_loss(self.reg, user_emb, pos_item_emb) + cl_loss# epoch_rec_loss.append(rec_loss.item()), epoch_cl_loss.append(cl_loss.item())# Backward and optimizeoptimizer.zero_grad()batch_loss.backward()optimizer.step()if n % 100==0 and n>0:print('training:', epoch + 1, 'batch', n, 'rec_loss:', rec_loss.item(), 'cl_loss', cl_loss.item())with torch.no_grad():self.user_emb, self.item_emb = self.model()hot_emb = torch.cat([self.user_emb[hot_uidx],self.item_emb[hot_iidx]],0)cold_emb = torch.cat([self.user_emb[cold_uidx],self.item_emb[cold_iidx]],0)self.eval_uniform(epoch, hot_emb, cold_emb)hot_user_mag = self.cal_sim(epoch, hot_uidx, self.user_emb, self.item_emb,mode='hot')self.cal_sim(epoch, norm_uidx, self.user_emb, self.item_emb, mode='norm')cold_user_mag= self.cal_sim(epoch, cold_uidx, self.user_emb, self.item_emb, mode='cold')hot_item_mag = self.item_magnitude(epoch, hot_iidx, self.item_emb,mode='hot')self.item_magnitude(epoch, norm_iidx, self.item_emb, mode='norm')cold_item_mag = self.item_magnitude(epoch, cold_iidx, self.item_emb, mode='cold')print('training:',epoch + 1,'U_mag_ratio:',hot_user_mag/cold_user_mag, 'I_mag_ratio:',hot_item_mag/cold_item_mag)# self.getTopSimNeg(hot_uidx, self.user_emb,self.item_emb, 100)# self.getTopSimNeg(norm_uidx,self.user_emb,self.item_emb, 100)# self.getTopSimNeg(cold_uidx,self.user_emb,self.item_emb, 100)# epoch_rec_loss = np.array(epoch_rec_loss).mean()# self.loss.extend([epoch_rec_loss,epoch_cl_loss,hot_pair_uniform_loss.item(),random_item_uniform_loss.item()])# if epoch%5==0:#     self.save_emb(epoch, hot_emb, mode='hot')#     self.save_emb(epoch, random_emb, mode='random')self.fast_evaluation(epoch)# self.save_loss()self.user_emb, self.item_emb = self.best_user_emb, self.best_item_emb# self.save_emb(self.bestPerformance[0], hot_emb, mode='best_hot')# self.save_emb(self.bestPerformance[0], random_emb, mode='best_random')

hard_neg buffer

    def getHardNeg(self, user_idx, pos_idx, rec_user_emb, rec_item_emb,temperature):u_emb,i_emb = F.normalize(rec_user_emb[user_idx], dim=1),F.normalize(rec_item_emb[pos_idx], dim=1)pos_score =  (u_emb * i_emb).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)i_emb = i_emb.unsqueeze(0).expand(u_emb.size(0), -1, -1)neg_idx = torch.LongTensor(pos_idx).unsqueeze(0).expand(u_emb.size(0), -1).to(device)# if torch.all(self.hardNeg[user_idx])!=0:#     preNeg = self.hardNeg[user_idx]#     preNeg_emb = F.normalize(rec_item_emb[preNeg], dim=1)#     neg_idx = torch.cat([neg_idx,preNeg],1)#     i_emb = torch.cat([i_emb, preNeg_emb],1)ttl_score = (u_emb.unsqueeze(1) * i_emb).sum(dim=-1)indices = torch.topk(ttl_score, k=100)[1].detach()ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)rec_loss = -torch.log(pos_score / ttl_score + 10e-6)chosen_hardNeg = neg_idx[torch.arange(i_emb.size(0)).unsqueeze(1), indices]self.hardNeg[user_idx] = chosen_hardNegreturn torch.mean(rec_loss)

相关文章:

batch_softmax_loss

每个用户抽取一定数量的困难负样本,然后ssm def batch_softmax_loss_neg(self, user_idx, rec_user_emb, pos_idx, item_emb):user_emb rec_user_emb[user_idx]product_scores torch.matmul(F.normalize(user_emb, dim1), F.normalize(item_emb, dim1).transpose(…...

刘汉清:从生活到画布,宠物成为灵感源泉

出生于中国镇江的艺术家刘汉清,其作品展现出他对日常生活的深入洞察力,以及对美的独特理解。他的作品通常没有视觉参考,而是通过对他周围环境的理解,尤其是他的宠物,来进行创作。 在刘汉清的创作过程中,他…...

【LeetCode】240.搜索二维矩阵Ⅱ

题目 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性: 每行的元素从左到右升序排列。每列的元素从上到下升序排列。 示例 1: 输入:matrix [[1,4,7,11,15],[2,5,8,12,19],[3,6,9,16,22],[10,13,…...

SED正则表达式中[方括号]的特殊处理

今天被这个方括号懵晕了,特此记录 例如: 去除输入字符串“1[2.3]4[ab,c]”中的所有方括号和逗号: $ echo "1[2.3]4[ab,c]"|sed -e "s/[,\]\[]//g" 1[2.3]4[ab,c] It doesnt work! 原因:Regular Expressi…...

Android 音频开发

在Android平台上进行音频开发,您需要掌握以下关键知识点: Android平台基础知识:熟悉Android操作系统的基本架构、组件和应用开发的基本概念。 音频API:了解Android提供的音频相关API,主要包括android.media.AudioReco…...

Java8新特性,Lambda,Stream流

Java8新特性,Lambda,Stream流 Java8版本在2014年3月18日发布,为Java语言添加了很多重要的新特性。新特性包括:Lambda表达式、方法引用、默认方法、新的时间日期API、Stream API、Optional类等等。这些新特性大大增强了Java语言的表达能力,使…...

模型训练之train.py代码解析

题目 作者:安静到无声 个人主页 from __future__ import absolute_import from __future__ import division from __future__ import print_function这段代码使用了Python 2.x的__future__模块来导入Python 3.x的一些特性。在Python 2.x中,使用print语句来输出内容,而在Py…...

linux 复习

vim 使用 一般模式 、 命令模式、编辑模式 esc 进入一般模式 i 进入编辑模式 shift: 进入命令模式 yy p 复制粘贴 5yy 复制当前开始的5行 dd 删除 5dd 删除当前开始的5行 u撤销操作 ctrlr 恢复 shiftg 滚动最底部 gg 滚动最顶 输入数字 然后shiftg 跳转到指定行 用户操作…...

C语言刷题------(2)

C语言刷题——————(2) 刷题网站:题库 - 蓝桥云课 (lanqiao.cn) First Question:时间显示 题目描述 小蓝要和朋友合作开发一个时间显示的网站。 在服务器上,朋友已经获取了当前的时间,用一个整数表…...

JVM 之 OopMap 和 RememberedSet

前几天看周志明的《深入 Java 虚拟机》,感觉对 OopMap 和 RememberedSet 的介绍,看起来不太容易理解清楚。今天查了一些资料,并结合自己的一些猜想,把对这两种数据结构的理解写出来。目的只是为了简单易懂,而且多有推测…...

Original error: gsmCall method is only available for emulators

在夜神模拟器执行报错 self.driver.make_gsm_call(5551234567, GsmCallActions.CALL)意思是gsmCall这个命令不支持,只支持下面这些命令 selenium.common.exceptions.UnknownMethodException: Message: Unknown mobile command "gsmCall". Only shell,exe…...

React Native从文本内容尾部截取显示省略号

<Textstyle{styles.mMeNickname}ellipsizeMode"tail"numberOfLines{1}>{userInfo.nickname}</Text> 参考链接&#xff1a; https://www.reactnative.cn/docs/text#ellipsizemode https://chat.xutongbao.top/...

机器学习笔记之优化算法(十一)凸函数铺垫:梯度与方向导数

机器学习笔记之优化算法——凸函数铺垫&#xff1a;梯度与方向导数 引言回顾&#xff1a;偏导数方向余弦方向导数方向导数的几何意义方向导数的定义 方向导数与偏导数之间的关联关系证明过程 梯度 ( Gradient ) (\text{Gradient}) (Gradient) 引言 本节作为介绍凸函数的铺垫&a…...

探究Vue源码:mustache模板引擎(11) 递归处理循环逻辑并收尾算法处理

好 在上文 探究Vue源码:mustache模板引擎(10) 解决不能用连续点符号找到多层对象问题&#xff0c;为编译循环结构做铺垫 我们解决了js字符串没办法通过 什么点什么拿到对象中的值的问题 这个大家需要记住 因为这个方法的编写之前是当做面试题出现过的 那么 本文 我们就要去写上…...

STM32 CubeMX USB_CDC(USB_转串口)

STM32 CubeMX STM32 CubeMX 定时器&#xff08;普通模式和PWM模式&#xff09; STM32 CubeMX一、STM32 CubeMX 设置USB时钟设置USB使能UBS功能选择 二、代码部分添加代码实验效果 ![请添加图片描述](https://img-blog.csdnimg.cn/a7333bba478441ab950a66fc63f204fb.png)printf发…...

机器学习——卷积神经网络基础

卷积神经网络&#xff08;Convolutional Neural Network&#xff1a;CNN&#xff09; 卷积神经网络是人工神经网络的一种&#xff0c;是一种前馈神经网络。最早提出时的灵感来源于人类的神经元。 通俗来讲&#xff0c;其主要的操作就是&#xff1a;接受输入层的输入信息&…...

端到端自动驾驶前沿论文盘点(pdf+代码)

现在的自动驾驶&#xff0c;大多数还是采用的模块化架构&#xff0c;但这种架构的缺陷十分明显&#xff1a;在一个自动驾驶系统里&#xff0c;可能会包含很多个模型&#xff0c;每个模型都要专门进行训练、优化、迭代&#xff0c;随着模型的不断进化&#xff0c;参数量不断提高…...

2023年中期奶粉行业分析报告(京东数据开放平台)

根据国家统计局和民政部数据公布&#xff0c;2022年中国结婚登记数创造了1980年&#xff08;有数据公布&#xff09;以来的历史新低&#xff0c;共计683.3万对。相较于2013年巅峰时期的数据&#xff0c;2022年全国结婚登记对数已接近“腰斩”。 2023年“520”期间的结婚登记数…...

web集群学习:基于CentOS 7构建 LVS-DR 群集并配置服务启动脚本

目录 1、环境准备 2、配置lvs服务启动脚本 1、在RS上分别配置服务启动脚本 2、在lvs director上配置服务启动脚本 3、客户端测试 配置LVS-DR模式主要注意的有 1、vip绑定在RS的lo接口&#xff1b; 2、RS做arp抑制&#xff1b; 1、环境准备 VIP192.168.95.10 RS1192.168…...

Flask 高级应用:使用蓝图模块化应用和 JWT 实现安全认证

本文将探讨 Flask 的两个高级特性&#xff1a;蓝图&#xff08;Blueprints&#xff09;和 JSON Web Token&#xff08;JWT&#xff09;认证。蓝图让我们可以将应用模块化&#xff0c;以便更好地组织代码&#xff1b;而 JWT 认证是现代 Web 应用中常见的一种安全机制。 一、使用…...

HTML 语义化

目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案&#xff1a; 语义化标签&#xff1a; <header>&#xff1a;页头<nav>&#xff1a;导航<main>&#xff1a;主要内容<article>&#x…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

通过Wrangler CLI在worker中创建数据库和表

官方使用文档&#xff1a;Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后&#xff0c;会在本地和远程创建数据库&#xff1a; npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库&#xff1a; 现在&#xff0c;您的Cloudfla…...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

&#x1f31f; 什么是 MCP&#xff1f; 模型控制协议 (MCP) 是一种创新的协议&#xff0c;旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议&#xff0c;它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

是否存在路径(FIFOBB算法)

题目描述 一个具有 n 个顶点e条边的无向图&#xff0c;该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序&#xff0c;确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数&#xff0c;分别表示n 和 e 的值&#xff08;1…...

初学 pytest 记录

安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...

面向无人机海岸带生态系统监测的语义分割基准数据集

描述&#xff1a;海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而&#xff0c;目前该领域仍面临一个挑战&#xff0c;即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...

基于Springboot+Vue的办公管理系统

角色&#xff1a; 管理员、员工 技术&#xff1a; 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能&#xff1a; 该办公管理系统是一个综合性的企业内部管理平台&#xff0c;旨在提升企业运营效率和员工管理水…...