当前位置: 首页 > news >正文

【Pytorch】Fizz Buzz

在这里插入图片描述

文章目录

  • 1 数据编码
  • 2 网络搭建
  • 3 网络配置,训练
  • 4 结果预测
  • 5 翻车现场

学习参考来自:

  • Fizz Buzz in Tensorflow
  • https://github.com/wmn7/ML_Practice/tree/master/2019_06_10
  • Fizz Buzz in Pytorch

I need you to print the numbers from 1 to 100, except that if the number is divisible by 3 print “fizz”, if it’s divisible by 5 print “buzz”, and if it’s divisible by 15 print “fizzbuzz”.

编程题很简单,我们用 MLP 实现试试

思路,训练集数据101~1024,对其进行某种规则的编码,标签为经分类 one-hot 编码后的标签
测试集,1~100

don’t say so much, show me the code.

1 数据编码

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Datadef binary_encode(i, num_digits):"""将每个input转换为binary digits(转换为二进制的表示, 最多可是表示2^num_digits):param i::param num_digits::return:"""return np.array([i >> d & 1 for d in range(num_digits)])

编码形式,依次除以 2 0 , 1 , 2 , 3 , . . . 2^{0,1,2,3,...} 20,1,2,3,...,结果按位与 1

m & 1,结果为 0 表示 m 为偶数, 结果为 1 表示 m 为奇数

> > m >> m >>m 右移表示除以 2 m 2^m 2m

第一位就能表示奇偶了,所有数字编码都不一样

eg,101 进行 num_digits=10 编码后结果为 1 0 1 0 0 1 1 0 0 0

步骤

101 / 1 = 101 奇数 1
101 / 2 = 50 偶数 0
101 / 4 = 25 奇数 1
101 / 8 = 12 偶数 0
101 / 16 = 6 偶数 0
101 / 32 = 3 奇数 1
101 / 64 = 1 奇数 1
101 / 128 = 0 偶数 0
101 / 256= 0 偶数 0
101 / 512= 0 偶数 0

标签,0,1,2,3 四个类别

def fizz_buzz_encode(i):"""将output转换为lebel:param i::return:"""if i % 15 == 0:  # fizzbuzzreturn 3elif i % 5 == 0:  # buzzreturn 2elif i % 3 == 0:  # fizzreturn 1else:return 0

编码长度设定,数据集 101 ~ 1024

NUM_DIGITS = 10
trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2**NUM_DIGITS)])  # 101~1024
trY = np.array([fizz_buzz_encode(i) for i in range(101, 2**NUM_DIGITS)])# print(len(trX), len(trY))  # 923 923
# print(trX[:5])
"""
[[1 0 1 0 0 1 1 0 0 0][0 1 1 0 0 1 1 0 0 0][1 1 1 0 0 1 1 0 0 0][0 0 0 1 0 1 1 0 0 0][1 0 0 1 0 1 1 0 0 0]]
"""
# print(trY[:5])  # [0 1 0 0 3]

2 网络搭建

搭建简单的 MLP 网络

class FizzBuzzModel(nn.Module):def __init__(self, in_features, out_classes, hidden_size, n_hidden_layers):super(FizzBuzzModel,self).__init__()layers = []for i in range(n_hidden_layers):layers.append(nn.Linear(hidden_size,hidden_size))# layers.append(nn.Dropout(0.5))layers.append(nn.BatchNorm1d(hidden_size))layers.append(nn.ReLU())self.inputLayer = nn.Linear(in_features, hidden_size)self.relu = nn.ReLU()self.layers = nn.Sequential(*layers)  # 重复的搭建隐藏层self.outputLayer = nn.Linear(hidden_size, out_classes)def forward(self, x):x = self.inputLayer(x)x = self.relu(x)x = self.layers(x)out = self.outputLayer(x)return out

初始化网络,看看网络结构

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# define the model
simpleModel = FizzBuzzModel(NUM_DIGITS, 4, 150, 3).to(device)
print(simpleModel)
"""
FizzBuzzModel((inputLayer): Linear(in_features=10, out_features=150, bias=True)(relu): ReLU()(layers): Sequential((0): Linear(in_features=150, out_features=150, bias=True)(1): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU()(3): Linear(in_features=150, out_features=150, bias=True)(4): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU()(6): Linear(in_features=150, out_features=150, bias=True)(7): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(8): ReLU())(outputLayer): Linear(in_features=150, out_features=4, bias=True)
)
"""

输入 10, 输出4,隐藏层维度 150,隐藏层重复了 3 次

3 网络配置,训练

定义下超参数,损失函数,优化器,载入数据训练,输出训练精度与损失

# Loss and optimizer
learning_rate = 0.05
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(simpleModel.parameters(), lr=learning_rate)
optimizer = torch.optim.SGD(simpleModel.parameters(), lr=learning_rate)# 使用batch进行训练
FizzBuzzDataset = Data.TensorDataset(torch.from_numpy(trX).float().to(device),torch.from_numpy(trY).long().to(device))loader = Data.DataLoader(dataset=FizzBuzzDataset,batch_size=128*5,shuffle=True)# 进行训练
simpleModel.train()
epochs = 3000for epoch in range(1, epochs):for step, (batch_x, batch_y) in enumerate(loader):out = simpleModel(batch_x)  # 前向传播loss = criterion(out, batch_y)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播optimizer.step()  # 随机梯度下降correct = 0total = 0_, predicted = torch.max(out.data, 1)total += batch_y.size(0)correct += (predicted == batch_y).sum().item()acc = 100*correct/totalprint('Epoch : {:0>4d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, loss, acc))"""
Epoch : 0001 | Loss : 1.5343 | Train Accuracy : 14.63 %
Epoch : 0002 | Loss : 1.9779 | Train Accuracy : 42.58 %
Epoch : 0003 | Loss : 2.4198 | Train Accuracy : 53.41 %
Epoch : 0004 | Loss : 1.7360 | Train Accuracy : 53.41 %
Epoch : 0005 | Loss : 1.3161 | Train Accuracy : 49.73 %
Epoch : 0006 | Loss : 1.4866 | Train Accuracy : 22.75 %
Epoch : 0007 | Loss : 1.3993 | Train Accuracy : 25.57 %
Epoch : 0008 | Loss : 1.2428 | Train Accuracy : 28.49 %
Epoch : 0009 | Loss : 1.1906 | Train Accuracy : 44.31 %
Epoch : 0010 | Loss : 1.1929 | Train Accuracy : 52.44 %
...
Epoch : 2990 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2991 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2992 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2993 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2994 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2995 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2996 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2997 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2998 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2999 | Loss : 0.0000 | Train Accuracy : 100.00%
"""

训练集上精度是 OK 的,能到 100%,下面看看测试集上的精度

4 结果预测

把 one-hot 标签转化成 fizz buzz 的形式

def fizz_buzz_decode(i, prediction):return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

载入测试集,开始预测

simpleModel.eval()
# 进行预测
testX = np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
predicts = simpleModel(torch.from_numpy(testX).float().to(device))
# 预测的结果
_, res = torch.max(predicts, 1)
print(res)
"""
tensor([0, 0, 0, 1, 0, 0, 0, 2, 1, 0, 1, 3, 3, 1, 1, 0, 0, 0, 0, 0, 0, 3, 1, 0,0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 1, 1, 1, 0,0, 0, 0, 0], device='cuda:0')
"""# 格式的转换
predictions = [fizz_buzz_decode(i, prediction) for (i, prediction) in zip(range(1, 101), res)]
print(predictions)
"""
['1', '2', '3', 'fizz', '5', '6', '7', 'buzz', 'fizz', '10', 'fizz', 'fizzbuzz', 'fizzbuzz', 'fizz', 'fizz', '16', '17', '18', '19', '20', '21', 'fizzbuzz', 'fizz', '24', '25', '26', '27', '28', '29', '30', 'fizz', '32', '33', '34', '35', '36', '37', '38', '39', 'fizz', '41', 'fizz', '43', '44', '45', 'fizz', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', 'fizz', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', 'fizz', 'fizz', 'fizz', 'fizz', 'buzz', 'buzz', 'fizz', '80', '81', '82', '83', '84', '85', '86', '87', 'fizzbuzz', '89', '90', 'fizz', '92', 'fizz', 'fizz', 'fizz', '96', '97', '98', '99', '100']
"""

5 翻车现场

对比下标签

labels = []
for i in range(1, 101):if i % 15 == 0:  # fizzbuzzlabels.append("fizzbuzz")elif i % 5 == 0:  # buzzlabels.append("buzz")elif i % 3 == 0:  # fizzlabels.append("fizz")else:labels.append(str(i))
print(labels)
print(labels == predictions)"""
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
False
"""

哈哈哈, False 翻车了,尝试了很多次,很难 True

相关文章:

【Pytorch】Fizz Buzz

文章目录 1 数据编码2 网络搭建3 网络配置&#xff0c;训练4 结果预测5 翻车现场 学习参考来自&#xff1a; Fizz Buzz in Tensorflowhttps://github.com/wmn7/ML_Practice/tree/master/2019_06_10Fizz Buzz in Pytorch I need you to print the numbers from 1 to 100, excep…...

C++ Primer Plus第十四章笔记

目录 1.包含对象成员的类 valarray类简介 1.2 Student类的设计 1.3 接口和实现 1.4 C和约束 2. 私有继承 2.1 私有继承和组合的异同 2.2 初始化基类组件 2.3 访问基类的方法 2.4 访问基类对象 2.5 访问基类的友元函数 2.5 使用组合还是私有继承 3. 保护继承 4. 使…...

CentOS 7 mini 运行环境搭建与测试——CentOS Mini 安装ifconfig工具【云原生开发部署实践笔记】

云原生开发部署实践笔记 一、开发测试环境搭建与测试 1.1 Linux运行环境的搭建与测试 虽然CentOS已经更新到Stream 9 版本&#xff0c;但基于大多数企业和单位多数使用CentOS 7版本作为运行底座&#xff0c;7版本也一直在更行维护&#xff0c;此实践基于CentOS 7 Mini版本搭…...

案例061:基于微信小程序的互助学习系统

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…...

【ELK03】ES 索引的Mapping映射详解、数据类型和settings属性设置

一、ES 索引的映射和设置 1.MAPPING 映射(MAPPING)就是es中一个决定了文档如何存储,如何生成索引,字段各种类型定义的过程.类似于我们在关系型数据库中创建一个表格数据之前先定义表格有哪些字段,每个字段是什么类型,然后数据会按照这个配置写入表格,ES中同样是这个过程,它由…...

线性代数入门与学习笔记

该内容为重拾部分线性代数知识的学习笔记&#xff0c;内容上更多的是为了解决问题而学习的内容&#xff0c;并非系统化的学习。 针对的问题为&#xff1a;Music算法推导求解过程中的矩阵计算知识。 学习的内容包括&#xff1a;矩阵原理、矩阵行列式、矩阵的秩、线性变换矩阵变换…...

Linux安全学习路标

1. 操作系统基础知识 首先&#xff0c;你需要建立坚实的操作系统基础知识&#xff0c;包括Linux文件系统和目录结构、Linux进程管理、权限管理等基本概念。 2. 网络和通信安全 学习关于网络和通信安全的基础知识&#xff0c;包括TCP/IP协议栈、网络攻击类型、防火墙配置、网…...

常见的中间件--消息队列中间件测试点

最近刷题&#xff0c;看到了有问中间件的题目&#xff0c;于是整理了一些中间件的知识&#xff0c;大多是在小破站上的笔记&#xff0c;仅供大家参考~ 主要分为七个部分来分享&#xff1a; 一、常见的中间件 二、什么是队列&#xff1f; 三、常见消息队列MQ的比较 四、队列…...

【USRP】5G / 6G OAI 系统 5g / 6G OAI system

面向5G/6G科研应用 USRP专门用于5G/6G产品的原型开发与验证。该系统可以在实验室搭建一个真实的5G 网络&#xff0c;基于开源的代码&#xff0c;专为科研用户设计。 软件无线电架构&#xff0c;构建真实5G移动通信系统 X410 采用了目前流行的异构式系统&#xff0c;融合了FP…...

ubuntu20.04设置开机自启动jar(依赖其他服务)

目的&#xff1a; 有的时候我们的项目是部署在物理机上给其他公司员工使用&#xff0c;对于他们来说操作越简单越好。所以我需要实现将我的jar部署在ubuntu上&#xff0c;实现开机自启。&#xff08;我的项目依赖emqx服务&#xff09;。 步骤&#xff1a; 切换到system目录 …...

【GEE笔记】在线分类流程,标注样本点、分类和精度评价

GEE在线分类流程 介绍 GEE&#xff08;Google Earth Engine&#xff09;是一个强大的地理信息处理平台&#xff0c;可以实现在线的遥感影像分析和处理。本文将介绍如何使用GEE进行在线的分类流程&#xff0c;包括标注样本点、分类和精度评价。本文以2020年5月至8月的哨兵2影像…...

MATLAB基础运算

矩阵和数字相乘 就是矩阵里面每个元素跟这个数字乘一遍 矩阵和矩阵相乘 能不能相乘&#xff0c;需要前面矩阵的列数等于后面矩阵的行数&#xff0c;出来的矩阵大小是前面矩阵的行数*后面矩阵的列数。 所以大家会发现&#xff0c;矩阵相乘如果前后调转了&#xff0c;结果会完全…...

Linux DAC权限的简单应用

Linux的DAC&#xff08;Discretionary Access Control&#xff09;权限模型是一种常见的访问控制机制&#xff0c;它用于管理文件和目录的访问权限。作为一名经验丰富的Linux系统安全工程师&#xff0c;我会尽可能以简单明了的方式向计算机小白介绍Linux DAC权限模型。 在Linu…...

JVS低代码表单引擎:数据校验与处理的先锋

随着信息技术的迅速发展&#xff0c;数据校验与处理已经成为了各类应用中不可或缺的一环。尤其是在涉及敏感信息&#xff0c;如密码处理时&#xff0c;其安全性和准确性显得尤为重要。JVS低代码表单引擎提供了强大的文本组件触发逻辑校验功能&#xff0c;它能够在用户填写数据的…...

clickhouse删除partition分区数据

clickhouse分布式表tencent_table_20231208_DIST&#xff0c;本地表tencent_table_20231208_local&#xff1b; 30台clickhouse存储服务器&#xff1b; 本地表&#xff1a;tencent_table_20231208_local CREATE TABLE tencent_sz.tencent_table_20231208_local (id Int64 DEFA…...

持续集成交付CICD:CentOS 7 安装 Nexus 3.63

目录 一、实验 1.CentOS 7 安装Nexus3.63 二、问题 1.安装Nexus报错 2.Nexus启动停止相关命令 一、实验 1.CentOS 7 安装Nexus3.63 &#xff08;1&#xff09;当前操作系统版本&JDK版本 cat /etc/redhat-releasejava -version&#xff08;2&#xff09;下载Nexus新…...

Apache Flink(十):Flink集群基础环境搭建-JDK及MySQL搭建

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 目录...

LVS-DR+Keepalived+动静分离实验

架构图 解释一下架构&#xff0c;大概就是用Keepalived实现两台DR服务器的LVS负载均衡&#xff0c;然后后端服务器是两台Nginx服务器两台Tomcat服务器并且实现动静分离这个实验其实就是把 LVS-DRKeepalived 和 动静分离 给拼起来&#xff0c;真的是拼起来&#xff0c;两个部分…...

java面试题-Hashmap、Hashtable、ConcurrentHashMap原理

远离八股文&#xff0c;面试大白话&#xff0c;通俗且易懂 看完后试着用自己的话复述出来。有问题请指出&#xff0c;有需要帮助理解的或者遇到的真实面试题不知道怎么总结的也请评论中写出来&#xff0c;大家一起解决。 java面试题汇总-目录-持续更新中 Hashmap和hashtable存储…...

数据可视化:解锁企业经营的智慧之道

在现代企业管理中&#xff0c;数据可视化已经成为了一项重要的工具。它不仅仅是简单地展示数据&#xff0c;更是提供了深入理解数据、做出更明智决策的方法。作为一名可视化设计从业人员&#xff0c;我经手过一些企业自用的数据可视化项目&#xff0c;今天就来和大家聊聊数据可…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销&#xff0c;平衡网络负载&#xff0c;延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

边缘计算医疗风险自查APP开发方案

核心目标:在便携设备(智能手表/家用检测仪)部署轻量化疾病预测模型,实现低延迟、隐私安全的实时健康风险评估。 一、技术架构设计 #mermaid-svg-iuNaeeLK2YoFKfao {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg…...

Python实现prophet 理论及参数优化

文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候&#xff0c;写过一篇简单实现&#xff0c;后期随着对该模型的深入研究&#xff0c;本次记录涉及到prophet 的公式以及参数调优&#xff0c;从公式可以更直观…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...

VM虚拟机网络配置(ubuntu24桥接模式):配置静态IP

编辑-虚拟网络编辑器-更改设置 选择桥接模式&#xff0c;然后找到相应的网卡&#xff08;可以查看自己本机的网络连接&#xff09; windows连接的网络点击查看属性 编辑虚拟机设置更改网络配置&#xff0c;选择刚才配置的桥接模式 静态ip设置&#xff1a; 我用的ubuntu24桌…...

音视频——I2S 协议详解

I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议&#xff0c;专门用于在数字音频设备之间传输数字音频数据。它由飞利浦&#xff08;Philips&#xff09;公司开发&#xff0c;以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...

Caliper 配置文件解析:fisco-bcos.json

config.yaml 文件 config.yaml 是 Caliper 的主配置文件,通常包含以下内容: test:name: fisco-bcos-test # 测试名称description: Performance test of FISCO-BCOS # 测试描述workers:type: local # 工作进程类型number: 5 # 工作进程数量monitor:type: - docker- pro…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...

Pydantic + Function Calling的结合

1、Pydantic Pydantic 是一个 Python 库&#xff0c;用于数据验证和设置管理&#xff0c;通过 Python 类型注解强制执行数据类型。它广泛用于 API 开发&#xff08;如 FastAPI&#xff09;、配置管理和数据解析&#xff0c;核心功能包括&#xff1a; 数据验证&#xff1a;通过…...

ui框架-文件列表展示

ui框架-文件列表展示 介绍 UI框架的文件列表展示组件&#xff0c;可以展示文件夹&#xff0c;支持列表展示和图标展示模式。组件提供了丰富的功能和可配置选项&#xff0c;适用于文件管理、文件上传等场景。 功能特性 支持列表模式和网格模式的切换展示支持文件和文件夹的层…...