当前位置: 首页 > news >正文

batch_softmax_loss

每个用户抽取一定数量的困难负样本,然后ssm

    def batch_softmax_loss_neg(self, user_idx, rec_user_emb, pos_idx, item_emb):user_emb = rec_user_emb[user_idx]product_scores = torch.matmul(F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1).transpose(0, 1))pos_score = (rec_user_emb[user_idx] * item_emb[pos_idx]).sum(dim=-1)pos_score = torch.exp(pos_score / self.temp2)train_mask = self.data.ui_adj[user_idx, self.data.user_num:].toarray()train_mask = torch.tensor(train_mask).cuda()product_scores = product_scores * (1 - train_mask)neg_score, indices = product_scores.topk(500, dim=1, largest=True, sorted=True)neg_score = torch.exp(neg_score[:,400:] / self.temp2).sum(dim=-1)loss = -torch.log(pos_score / (pos_score + neg_score + 10e-6))return torch.mean(loss)
def batch_softmax_loss_neg(user_emb, pos_item_emb, neg_item_emb, temperature):user_emb, pos_item_emb, neg_item_emb = F.normalize(user_emb, dim=1), F.normalize(pos_item_emb, dim=1), F.normalize(neg_item_emb, dim=1)pos_score = (user_emb * pos_item_emb).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)user_emb = user_emb.unsqueeze(1).expand(user_emb.shape[0],neg_item_emb.shape[1],user_emb.shape[1])neg_score = (user_emb * neg_item_emb).sum(dim=-1) # user_emb(n*1*d) neg_item_emb = (n*m*d)neg_score = torch.exp(neg_score / temperature).sum(dim=-1)loss = -torch.log(pos_score / (pos_score + neg_score + 10e-6))return torch.mean(loss)

均匀性损失(错误案例)

# def cal_uniform_loss(user_emb, item_emb):
#     user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
#     distance = user_emb - item_emb  # n*d
#     gaussian_potential = torch.exp(-2 * torch.norm(distance,p=2,dim=1))
#     E_gaussian_potential = gaussian_potential.mean()
#     return torch.log(E_gaussian_potential)

DNS

def DNSbpr(user_emb, pos_item_emb, neg_item_emb):pos_score = torch.mul(user_emb, pos_item_emb).sum(dim=1)user_emb = user_emb.unsqueeze(1).expand(user_emb.shape[0], neg_item_emb.shape[1], user_emb.shape[1])ttl_socre = (user_emb * neg_item_emb).sum(dim=-1)neg_score = torch.max(ttl_socre, dim=1).valuesloss = -torch.log(10e-6 + torch.sigmoid(pos_score - neg_score))return torch.mean(loss)

带margin的infonce

def InfoNCE_margin(view1, view2, temperature, margin, b_cos = True):if b_cos:view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)pos_score = (view1 * view2).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)margin = margin * torch.eye(view1.shape[0])ttl_score = torch.matmul(view1, view2.transpose(0, 1))ttl_score += margin.cuda(0)ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)cl_loss = -torch.log(pos_score / ttl_score+10e-6)return torch.mean(cl_loss)def InfoNCE_tau(view1, view2, temperature):view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)pos_score = (view1 * view2).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)ttl_score = torch.matmul(view1, view2.transpose(0, 1))ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)cl_loss = -torch.log(pos_score / ttl_score+10e-6)return torch.mean(cl_loss)def batch_bpr_loss(user_emb, item_emb):pos_score = torch.mul(user_emb, item_emb).sum(dim=1)neg_score = torch.matmul(user_emb, item_emb.transpose(0, 1)).mean(dim=1)loss = -torch.log(10e-6 + torch.sigmoid(pos_score - neg_score))return torch.mean(loss)def Dis_softmax(view1, view2, temperature, b_cos = True):if b_cos:view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)N,M = view1.shapepos_score = (view1 - view2).norm(p=2, dim=1)pos_score = torch.exp(pos_score / temperature)view1 = view1.unsqueeze(1).expand(N,N,M)view2 = view2.unsqueeze(0).expand(N,N,M)ttl_score = (view1 - view2).norm(p=2, dim=-1)ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)cl_loss = torch.log(pos_score / ttl_score+10e-6)return torch.mean(cl_loss)

LightGCN+对比学习

    def forward(self, perturbed=False):ego_embeddings = torch.cat([self.embedding_dict['user_emb'], self.embedding_dict['item_emb']], 0)all_embeddings = []all_embeddings_cl = ego_embeddingsfor k in range(self.n_layers):ego_embeddings = torch.sparse.mm(self.sparse_norm_adj, ego_embeddings)if perturbed:random_noise = torch.rand_like(ego_embeddings).cuda()ego_embeddings += torch.sign(ego_embeddings) * F.normalize(random_noise, dim=-1) * self.epsall_embeddings.append(ego_embeddings)if k==self.layer_cl-1:all_embeddings_cl +=  F.normalize(all_embeddings[1]-all_embeddings[0], dim=-1) * self.epsfinal_embeddings = torch.stack(all_embeddings, dim=1)final_embeddings = torch.mean(final_embeddings, dim=1)user_all_embeddings, item_all_embeddings = torch.split(final_embeddings, [self.data.user_num, self.data.item_num])user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embeddings_cl, [self.data.user_num, self.data.item_num])if perturbed:return user_all_embeddings, item_all_embeddings,user_all_embeddings_cl, item_all_embeddings_clreturn user_all_embeddings, item_all_embeddings
    def train(self):model = self.model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=self.lRate)hot_uidx, hot_iidx = self.select_ui_idx(500, mode='hot')cold_uidx, cold_iidx = self.select_ui_idx(500, mode='cold')norm_uidx, norm_iidx = self.select_ui_idx(500, mode='norm')iters = 10alphas_init = torch.tensor([1, 2], dtype=torch.float64).to(device)betas_init = torch.tensor([2, 1], dtype=torch.float64).to(device)weights_init = torch.tensor([1 - 0.05, 0.05], dtype=torch.float64).to(device)for epoch in range(self.maxEpoch):# epoch_rec_loss = []bmm_model = BetaMixture1D(iters, alphas_init, betas_init, weights_init)rec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)self.bmm_fit(rec_user_emb, rec_item_emb,torch.arange(self.data.user_num),np.random.randint(0,self.data.item_num, 100),bmm_model)for n, batch in enumerate(next_batch_pairwise(self.data, self.batch_size)):user_idx, pos_idx, rec_neg_idx = batchrec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)user_emb, pos_item_emb= rec_user_emb[user_idx], rec_item_emb[pos_idx]# rec_loss = self.batch_softmax_loss_neg(user_idx, rec_user_emb, pos_idx, rec_item_emb)# rec_neg_idx = torch.tensor(rec_neg_idx,dtype=torch.int64)# rec_neg_item_emb = rec_item_emb[rec_neg_idx]weight = self.getWeightSim(user_emb, pos_item_emb, bmm_model)rec_loss = weighted_SSM(user_emb,pos_item_emb,self.temp2,weight)cl_loss =  self.cl_rate * self.cal_cl_loss([user_idx,pos_idx],rec_user_emb,cl_user_emb,rec_item_emb,cl_item_emb)batch_loss =  rec_loss + l2_reg_loss(self.reg, user_emb, pos_item_emb) + cl_loss# epoch_rec_loss.append(rec_loss.item()), epoch_cl_loss.append(cl_loss.item())# Backward and optimizeoptimizer.zero_grad()batch_loss.backward()optimizer.step()if n % 100==0 and n>0:print('training:', epoch + 1, 'batch', n, 'rec_loss:', rec_loss.item(), 'cl_loss', cl_loss.item())with torch.no_grad():self.user_emb, self.item_emb = self.model()hot_emb = torch.cat([self.user_emb[hot_uidx],self.item_emb[hot_iidx]],0)cold_emb = torch.cat([self.user_emb[cold_uidx],self.item_emb[cold_iidx]],0)self.eval_uniform(epoch, hot_emb, cold_emb)hot_user_mag = self.cal_sim(epoch, hot_uidx, self.user_emb, self.item_emb,mode='hot')self.cal_sim(epoch, norm_uidx, self.user_emb, self.item_emb, mode='norm')cold_user_mag= self.cal_sim(epoch, cold_uidx, self.user_emb, self.item_emb, mode='cold')hot_item_mag = self.item_magnitude(epoch, hot_iidx, self.item_emb,mode='hot')self.item_magnitude(epoch, norm_iidx, self.item_emb, mode='norm')cold_item_mag = self.item_magnitude(epoch, cold_iidx, self.item_emb, mode='cold')print('training:',epoch + 1,'U_mag_ratio:',hot_user_mag/cold_user_mag, 'I_mag_ratio:',hot_item_mag/cold_item_mag)# self.getTopSimNeg(hot_uidx, self.user_emb,self.item_emb, 100)# self.getTopSimNeg(norm_uidx,self.user_emb,self.item_emb, 100)# self.getTopSimNeg(cold_uidx,self.user_emb,self.item_emb, 100)# epoch_rec_loss = np.array(epoch_rec_loss).mean()# self.loss.extend([epoch_rec_loss,epoch_cl_loss,hot_pair_uniform_loss.item(),random_item_uniform_loss.item()])# if epoch%5==0:#     self.save_emb(epoch, hot_emb, mode='hot')#     self.save_emb(epoch, random_emb, mode='random')self.fast_evaluation(epoch)# self.save_loss()self.user_emb, self.item_emb = self.best_user_emb, self.best_item_emb# self.save_emb(self.bestPerformance[0], hot_emb, mode='best_hot')# self.save_emb(self.bestPerformance[0], random_emb, mode='best_random')

hard_neg buffer

    def getHardNeg(self, user_idx, pos_idx, rec_user_emb, rec_item_emb,temperature):u_emb,i_emb = F.normalize(rec_user_emb[user_idx], dim=1),F.normalize(rec_item_emb[pos_idx], dim=1)pos_score =  (u_emb * i_emb).sum(dim=-1)pos_score = torch.exp(pos_score / temperature)i_emb = i_emb.unsqueeze(0).expand(u_emb.size(0), -1, -1)neg_idx = torch.LongTensor(pos_idx).unsqueeze(0).expand(u_emb.size(0), -1).to(device)# if torch.all(self.hardNeg[user_idx])!=0:#     preNeg = self.hardNeg[user_idx]#     preNeg_emb = F.normalize(rec_item_emb[preNeg], dim=1)#     neg_idx = torch.cat([neg_idx,preNeg],1)#     i_emb = torch.cat([i_emb, preNeg_emb],1)ttl_score = (u_emb.unsqueeze(1) * i_emb).sum(dim=-1)indices = torch.topk(ttl_score, k=100)[1].detach()ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)rec_loss = -torch.log(pos_score / ttl_score + 10e-6)chosen_hardNeg = neg_idx[torch.arange(i_emb.size(0)).unsqueeze(1), indices]self.hardNeg[user_idx] = chosen_hardNegreturn torch.mean(rec_loss)

相关文章:

batch_softmax_loss

每个用户抽取一定数量的困难负样本,然后ssm def batch_softmax_loss_neg(self, user_idx, rec_user_emb, pos_idx, item_emb):user_emb rec_user_emb[user_idx]product_scores torch.matmul(F.normalize(user_emb, dim1), F.normalize(item_emb, dim1).transpose(…...

刘汉清:从生活到画布,宠物成为灵感源泉

出生于中国镇江的艺术家刘汉清,其作品展现出他对日常生活的深入洞察力,以及对美的独特理解。他的作品通常没有视觉参考,而是通过对他周围环境的理解,尤其是他的宠物,来进行创作。 在刘汉清的创作过程中,他…...

【LeetCode】240.搜索二维矩阵Ⅱ

题目 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性: 每行的元素从左到右升序排列。每列的元素从上到下升序排列。 示例 1: 输入:matrix [[1,4,7,11,15],[2,5,8,12,19],[3,6,9,16,22],[10,13,…...

SED正则表达式中[方括号]的特殊处理

今天被这个方括号懵晕了,特此记录 例如: 去除输入字符串“1[2.3]4[ab,c]”中的所有方括号和逗号: $ echo "1[2.3]4[ab,c]"|sed -e "s/[,\]\[]//g" 1[2.3]4[ab,c] It doesnt work! 原因:Regular Expressi…...

Android 音频开发

在Android平台上进行音频开发,您需要掌握以下关键知识点: Android平台基础知识:熟悉Android操作系统的基本架构、组件和应用开发的基本概念。 音频API:了解Android提供的音频相关API,主要包括android.media.AudioReco…...

Java8新特性,Lambda,Stream流

Java8新特性,Lambda,Stream流 Java8版本在2014年3月18日发布,为Java语言添加了很多重要的新特性。新特性包括:Lambda表达式、方法引用、默认方法、新的时间日期API、Stream API、Optional类等等。这些新特性大大增强了Java语言的表达能力,使…...

模型训练之train.py代码解析

题目 作者:安静到无声 个人主页 from __future__ import absolute_import from __future__ import division from __future__ import print_function这段代码使用了Python 2.x的__future__模块来导入Python 3.x的一些特性。在Python 2.x中,使用print语句来输出内容,而在Py…...

linux 复习

vim 使用 一般模式 、 命令模式、编辑模式 esc 进入一般模式 i 进入编辑模式 shift: 进入命令模式 yy p 复制粘贴 5yy 复制当前开始的5行 dd 删除 5dd 删除当前开始的5行 u撤销操作 ctrlr 恢复 shiftg 滚动最底部 gg 滚动最顶 输入数字 然后shiftg 跳转到指定行 用户操作…...

C语言刷题------(2)

C语言刷题——————(2) 刷题网站:题库 - 蓝桥云课 (lanqiao.cn) First Question:时间显示 题目描述 小蓝要和朋友合作开发一个时间显示的网站。 在服务器上,朋友已经获取了当前的时间,用一个整数表…...

JVM 之 OopMap 和 RememberedSet

前几天看周志明的《深入 Java 虚拟机》,感觉对 OopMap 和 RememberedSet 的介绍,看起来不太容易理解清楚。今天查了一些资料,并结合自己的一些猜想,把对这两种数据结构的理解写出来。目的只是为了简单易懂,而且多有推测…...

Original error: gsmCall method is only available for emulators

在夜神模拟器执行报错 self.driver.make_gsm_call(5551234567, GsmCallActions.CALL)意思是gsmCall这个命令不支持,只支持下面这些命令 selenium.common.exceptions.UnknownMethodException: Message: Unknown mobile command "gsmCall". Only shell,exe…...

React Native从文本内容尾部截取显示省略号

<Textstyle{styles.mMeNickname}ellipsizeMode"tail"numberOfLines{1}>{userInfo.nickname}</Text> 参考链接&#xff1a; https://www.reactnative.cn/docs/text#ellipsizemode https://chat.xutongbao.top/...

机器学习笔记之优化算法(十一)凸函数铺垫:梯度与方向导数

机器学习笔记之优化算法——凸函数铺垫&#xff1a;梯度与方向导数 引言回顾&#xff1a;偏导数方向余弦方向导数方向导数的几何意义方向导数的定义 方向导数与偏导数之间的关联关系证明过程 梯度 ( Gradient ) (\text{Gradient}) (Gradient) 引言 本节作为介绍凸函数的铺垫&a…...

探究Vue源码:mustache模板引擎(11) 递归处理循环逻辑并收尾算法处理

好 在上文 探究Vue源码:mustache模板引擎(10) 解决不能用连续点符号找到多层对象问题&#xff0c;为编译循环结构做铺垫 我们解决了js字符串没办法通过 什么点什么拿到对象中的值的问题 这个大家需要记住 因为这个方法的编写之前是当做面试题出现过的 那么 本文 我们就要去写上…...

STM32 CubeMX USB_CDC(USB_转串口)

STM32 CubeMX STM32 CubeMX 定时器&#xff08;普通模式和PWM模式&#xff09; STM32 CubeMX一、STM32 CubeMX 设置USB时钟设置USB使能UBS功能选择 二、代码部分添加代码实验效果 ![请添加图片描述](https://img-blog.csdnimg.cn/a7333bba478441ab950a66fc63f204fb.png)printf发…...

机器学习——卷积神经网络基础

卷积神经网络&#xff08;Convolutional Neural Network&#xff1a;CNN&#xff09; 卷积神经网络是人工神经网络的一种&#xff0c;是一种前馈神经网络。最早提出时的灵感来源于人类的神经元。 通俗来讲&#xff0c;其主要的操作就是&#xff1a;接受输入层的输入信息&…...

端到端自动驾驶前沿论文盘点(pdf+代码)

现在的自动驾驶&#xff0c;大多数还是采用的模块化架构&#xff0c;但这种架构的缺陷十分明显&#xff1a;在一个自动驾驶系统里&#xff0c;可能会包含很多个模型&#xff0c;每个模型都要专门进行训练、优化、迭代&#xff0c;随着模型的不断进化&#xff0c;参数量不断提高…...

2023年中期奶粉行业分析报告(京东数据开放平台)

根据国家统计局和民政部数据公布&#xff0c;2022年中国结婚登记数创造了1980年&#xff08;有数据公布&#xff09;以来的历史新低&#xff0c;共计683.3万对。相较于2013年巅峰时期的数据&#xff0c;2022年全国结婚登记对数已接近“腰斩”。 2023年“520”期间的结婚登记数…...

web集群学习:基于CentOS 7构建 LVS-DR 群集并配置服务启动脚本

目录 1、环境准备 2、配置lvs服务启动脚本 1、在RS上分别配置服务启动脚本 2、在lvs director上配置服务启动脚本 3、客户端测试 配置LVS-DR模式主要注意的有 1、vip绑定在RS的lo接口&#xff1b; 2、RS做arp抑制&#xff1b; 1、环境准备 VIP192.168.95.10 RS1192.168…...

Flask 高级应用:使用蓝图模块化应用和 JWT 实现安全认证

本文将探讨 Flask 的两个高级特性&#xff1a;蓝图&#xff08;Blueprints&#xff09;和 JSON Web Token&#xff08;JWT&#xff09;认证。蓝图让我们可以将应用模块化&#xff0c;以便更好地组织代码&#xff1b;而 JWT 认证是现代 Web 应用中常见的一种安全机制。 一、使用…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python&#xff5c;GIF 解析与构建&#xff08;5&#xff09;&#xff1a;手搓截屏和帧率控制 一、引言 二、技术实现&#xff1a;手搓截屏模块 2.1 核心原理 2.2 代码解析&#xff1a;ScreenshotData类 2.2.1 截图函数&#xff1a;capture_screen 三、技术实现&…...

java_网络服务相关_gateway_nacos_feign区别联系

1. spring-cloud-starter-gateway 作用&#xff1a;作为微服务架构的网关&#xff0c;统一入口&#xff0c;处理所有外部请求。 核心能力&#xff1a; 路由转发&#xff08;基于路径、服务名等&#xff09;过滤器&#xff08;鉴权、限流、日志、Header 处理&#xff09;支持负…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法

树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源&#xff1a; http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作&#xff0c;无需更改相机配置。但是&#xff0c;一…...

Xshell远程连接Kali(默认 | 私钥)Note版

前言:xshell远程连接&#xff0c;私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

Day131 | 灵神 | 回溯算法 | 子集型 子集

Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 笔者写过很多次这道题了&#xff0c;不想写题解了&#xff0c;大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

2024年赣州旅游投资集团社会招聘笔试真

2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块&#xff0c;它提供了一个轻量级的 HTTP 服务器实现&#xff0c;主要用于构建基于 HTTP 的应用程序和服务。 功能介绍&#xff1a; 主要功能 HTTP服务器功能&#xff1a; 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

相机从app启动流程

一、流程框架图 二、具体流程分析 1、得到cameralist和对应的静态信息 目录如下: 重点代码分析: 启动相机前,先要通过getCameraIdList获取camera的个数以及id,然后可以通过getCameraCharacteristics获取对应id camera的capabilities(静态信息)进行一些openCamera前的…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解&#xff0c;适合用作学习或写简历项目背景说明。 &#x1f9e0; 一、概念简介&#xff1a;Solidity 合约开发 Solidity 是一种专门为 以太坊&#xff08;Ethereum&#xff09;平台编写智能合约的高级编…...

uniapp手机号一键登录保姆级教程(包含前端和后端)

目录 前置条件创建uniapp项目并关联uniClound云空间开启一键登录模块并开通一键登录服务编写云函数并上传部署获取手机号流程(第一种) 前端直接调用云函数获取手机号&#xff08;第三种&#xff09;后台调用云函数获取手机号 错误码常见问题 前置条件 手机安装有sim卡手机开启…...