当前位置: 首页 > news >正文

PyTorch 基础篇(1):Pytorch 基础

Pytorch 学习开始
入门的材料来自两个地方:

第一个是官网教程:WELCOME TO PYTORCH TUTORIALS,特别是官网的六十分钟入门教程 DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ。

第二个是韩国大神 Yunjey Choi 的 Repo:pytorch-tutorial,代码写得干净整洁。

目的:我是直接把 Yunjey 的教程的 python 代码挪到 Jupyter Notebook 上来,一方面可以看到运行结果,另一方面可以添加注释和相关资料链接。方便后面查阅。

顺便一题,我的 Pytorch 的版本是 0.4.1

  
  1. import torch
  2. print(torch.version)
  
  1. 0.4.1
  
  1. # 包
  2. import torch
  3. import torchvision
  4. import torch.nn as nn
  5. import numpy as np
  6. import torchvision.transforms as transforms

autograd(自动求导 / 求梯度) 基础案例 1

  
  1. # 创建张量(tensors)
  2. x = torch.tensor(1., requires_grad=True)
  3. w = torch.tensor(2., requires_grad=True)
  4. b = torch.tensor(3., requires_grad=True)
  5.  
  6. # 构建计算图( computational graph):前向计算
  7. y = w * x + b # y = 2 * x + 3
  8.  
  9. # 反向传播,计算梯度(gradients)
  10. y.backward()
  11.  
  12. # 输出梯度
  13. print(x.grad) # x.grad = 2
  14. print(w.grad) # w.grad = 1
  15. print(b.grad) # b.grad = 1
  
  1. tensor(2.)
  2. tensor(1.)
  3. tensor(1.)

autograd(自动求导 / 求梯度) 基础案例 2

  
  1. # 创建大小为 (10, 3) 和 (10, 2)的张量.
  2. x = torch.randn(10, 3)
  3. y = torch.randn(10, 2)
  4.  
  5. # 构建全连接层(fully connected layer)
  6. linear = nn.Linear(3, 2)
  7. print ('w: ', linear.weight)
  8. print ('b: ', linear.bias)
  9.  
  10. # 构建损失函数和优化器(loss function and optimizer)
  11. # 损失函数使用均方差
  12. # 优化器使用随机梯度下降,lr是learning rate
  13. criterion = nn.MSELoss()
  14. optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)
  15.  
  16. # 前向传播
  17. pred = linear(x)
  18.  
  19. # 计算损失
  20. loss = criterion(pred, y)
  21. print('loss: ', loss.item())
  22.  
  23. # 反向传播
  24. loss.backward()
  25.  
  26. # 输出梯度
  27. print ('dL/dw: ', linear.weight.grad)
  28. print ('dL/db: ', linear.bias.grad)
  29.  
  30. # 执行一步-梯度下降(1-step gradient descent)
  31. optimizer.step()
  32.  
  33. # 更底层的实现方式是这样子的
  34. # linear.weight.data.sub_(0.01 * linear.weight.grad.data)
  35. # linear.bias.data.sub_(0.01 * linear.bias.grad.data)
  36.  
  37. # 进行一次梯度下降之后,输出新的预测损失
  38. # loss的确变少了
  39. pred = linear(x)
  40. loss = criterion(pred, y)
  41. print(‘loss after 1 step optimization: ‘, loss.item())
  
  1. w: Parameter containing:
  2. tensor([[ 0.5180, 0.2238, -0.5470],
  3. [ 0.1531, 0.2152, -0.4022]], requires_grad=True)
  4. b: Parameter containing:
  5. tensor([-0.2110, -0.2629], requires_grad=True)
  6. loss: 0.8057981729507446
  7. dL/dw: tensor([[-0.0315, 0.1169, -0.8623],
  8. [ 0.4858, 0.5005, -0.0223]])
  9. dL/db: tensor([0.1065, 0.0955])
  10. loss after 1 step optimization: 0.7932316660881042

从 Numpy 装载数据

  
  1. # 创建Numpy数组
  2. x = np.array([[1, 2], [3, 4]])
  3. print(x)
  4.  
  5. # 将numpy数组转换为torch的张量
  6. y = torch.from_numpy(x)
  7. print(y)
  8.  
  9. # 将torch的张量转换为numpy数组
  10. z = y.numpy()
  11. print(z)
  
  1. [[1 2]
  2. [3 4]]
  3. tensor([[1, 2],
  4. [3, 4]])
  5. [[1 2]
  6. [3 4]]

输入工作流(Input pipeline)

  
  1. # 下载和构造CIFAR-10 数据集
  2. # Cifar-10数据集介绍:https://www.cs.toronto.edu/~kriz/cifar.html
  3. train_dataset = torchvision.datasets.CIFAR10(root=’…/…/…/data/’,
  4. train=True,
  5. transform=transforms.ToTensor(),
  6. download=True)
  7.  
  8. # 获取一组数据对(从磁盘中读取)
  9. image, label = train_dataset[0]
  10. print (image.size())
  11. print (label)
  12.  
  13. # 数据加载器(提供了队列和线程的简单实现)
  14. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  15. batch_size=64,
  16. shuffle=True)
  17.  
  18. # 迭代的使用
  19. # 当迭代开始时,队列和线程开始从文件中加载数据
  20. data_iter = iter(train_loader)
  21.  
  22. # 获取一组mini-batch
  23. images, labels = data_iter.next()
  24.  
  25.  
  26. # 正常的使用方式如下:
  27. for images, labels in train_loader:
  28. # 在此处添加训练用的代码
  29. pass
  
  1. Files already downloaded and verified
  2. torch.Size([3, 32, 32])
  3. 6

自定义数据集的 Input pipeline

  
  1. # 构建自定义数据集的方式如下:
  2. class CustomDataset(torch.utils.data.Dataset):
  3. def init(self):
  4. # TODO
  5. # 1. 初始化文件路径或者文件名
  6. pass
  7. def getitem(self, index):
  8. # TODO
  9. # 1. 从文件中读取一份数据(比如使用nump.fromfile,PIL.Image.open)
  10. # 2. 预处理数据(比如使用 torchvision.Transform)
  11. # 3. 返回数据对(比如 image和label)
  12. pass
  13. def len(self):
  14. # 将0替换成数据集的总长度
  15. return 0
  16. # 然后就可以使用预置的数据加载器(data loader)了
  17. custom_dataset = CustomDataset()
  18. train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
  19. batch_size=64,
  20. shuffle=True)
  21.  
  22. 预训练模型
  
  1. # 下载并加载预训练好的模型 ResNet-18
  2. resnet = torchvision.models.resnet18(pretrained=True)
  3.  
  4.  
  5. # 如果想要在模型仅对Top Layer进行微调的话,可以设置如下:
  6. # requieres_grad设置为False的话,就不会进行梯度更新,就能保持原有的参数
  7. for param in resnet.parameters():
  8. param.requires_grad = False
  9. # 替换TopLayer,只对这一层做微调
  10. resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is an example.
  11.  
  12. # 前向传播
  13. images = torch.randn(64, 3, 224, 224)
  14. outputs = resnet(images)
  15. print (outputs.size()) # (64, 100)
  
  1. torch.Size([64, 100])

保存和加载模型

  
  1. # 保存和加载整个模型
  2. torch.save(resnet, ‘model.ckpt’)
  3. model = torch.load(‘model.ckpt’)
  4.  
  5. # 仅保存和加载模型的参数(推荐这个方式)
  6. torch.save(resnet.state_dict(), ‘params.ckpt’)
  7. resnet.load_state_dict(torch.load(‘params.ckpt’))

 

相关文章:

PyTorch 基础篇(1):Pytorch 基础

Pytorch 学习开始 入门的材料来自两个地方: 第一个是官网教程:WELCOME TO PYTORCH TUTORIALS,特别是官网的六十分钟入门教程 DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ。 第二个是韩国大神 Yunjey Choi 的 Repo:pytorch-t…...

掌握Selenium4:详解各种定位方式

Selenium4中有多种元素定位方式,主要包括以下几种: 通过ID属性定位:根据元素的id属性进行定位。通过name属性定位:当元素没有id属性而有name属性时,可以使用name属性进行元素定位。通过class name定位:可以…...

go-fastfds部署心得

我是windows系统安装 Docker Desktop部署 docker run --name go-fastdfs(任意的一个名称) --privilegedtrue -t -p 3666:8080 -v /data/fasttdfs_data:/data -e GO_FASTDFS_DIR/data sjqzhang/go-fastdfs:lastest docker run:该命令用于运…...

Python第三次练习

Python 一、如何判断一个字符串是否是另一个字符串的子串二、如何验证一个字符串中的每一个字符均在另一个字符串中出现三、如何判定一个字符串中既有数字又有字母四、做一个注册登录系统 一、如何判断一个字符串是否是另一个字符串的子串 实现代码: string1 inp…...

从Java8升级到Java17,特色优化点

从Java8升级到Java17,特色优化点 一、局部变量类型推断二、switch表达式三、文本块四、Records五、模式匹配instanceof六、密封类七、NullPointerException 从Java 8 到 Java 20,Java 已经走过了漫长的道路,自 Java 8 以来,Java 生…...

js实现富文本

当涉及到使用 JavaScript 实现富文本时,一种常见的方法是使用一些现成的富文本编辑器库,比如: Quill:一个功能强大、易于集成的富文本编辑器,支持自定义样式和格式,提供丰富的插件和API。 TinyMCE&#xf…...

每日OJ题_算法_双指针②_力扣1089. 复写零

目录 力扣1089. 复写零 解析代码 力扣1089. 复写零 1089. 复写零 - 力扣(LeetCode) 难度 简单 给你一个长度固定的整数数组 arr ,请你将该数组中出现的每个零都复写一遍,并将其余的元素向右平移。 注意:请不要在…...

C++——红黑树

作者:几冬雪来 时间:2023年12月7日 内容:C——红黑树讲解 目录 前言: 红黑树的概念: 红黑树的性质: 红黑树的路径计算: 最长路径和最短路径: AVL树与红黑树的区别&#xff…...

【神化世界】asp网页500内部服务器错误的解决方法

问题解决方案记录 一、问题 在asp网页调试的时候,不小心改错了,好好的页面突然出现如下错误信息了: 二、解决方法 终于找到了问题所在,是sql语句出错造成的,特别记录一下。 正确的写法 sql"select * from mem…...

java面试题6

1.什么是Java中的泛型(Generic)? 答案:泛型是一种参数化类型的机制,在编译时提供类型安全性检查和重用代码的能力。使用泛型可以在编译时检测类型错误,并减少类型转换的需要。 2.Java中的反射(…...

(03)vite 处理 css

文章目录 系列全集vite 处理css流程vite如何解决协同开发,样式重复覆盖的问题?使用less通过配置,更改vite的css默认行为vite 利用postcss样式兼容低版本浏览器 系列全集 (01)vite 从启动服务器开始 (02&am…...

阿里云上传文件出现的问题解决(跨域设置)

跨域设置引起的问题 起因&#xff1a;开通对象存储服务后&#xff0c;上传文件限制在5M 大小&#xff0c;无法上传大文件。 1.查看报错信息 2.分析阿里云服务端响应内容 <?xml version"1.0" encoding"UTF-8"?> <Error><Code>Invali…...

利用JavaFX生成验证码图片

以下是一个基于 JavaFX 的验证码图片生成小程序的示例代码: import javafx.application.Application; import javafx.embed.swing.SwingFXUtils; import javafx.scene.Scene; import javafx.scene.canvas.Canvas; import javafx.scene.canvas.GraphicsContext; import javafx…...

6-55.汽车类的继承

根据给定的汽车类vehicle&#xff08;包含的数据成员有车轮个数wheels和车重weight&#xff09;声明&#xff0c;完成其中成员函数的定义&#xff0c;之后再定义其派生类并完成测试。 小车类car是它的派生类&#xff0c;其中包含载人数passenger_load。每个类都有相关数据的输出…...

SCI论文——respectively用法

respectively用于配对两组&#xff08;三组&#xff09;事物&#xff0c;表明后一组与前一组按照相同的顺序排列&#xff0c;从而使句意明确。一般是在句子的最后&#xff0c;而且在respectively的前面需要一个逗号“,” 一 、两组事物&#xff1a; 原则是尽可能靠近第二组的…...

解决方案 | 法大大电子签约加速农牧业数字化进程

近年来&#xff0c;我国农业技术得到快速发展&#xff0c;并开发出一批实用的数字农业技术产品&#xff0c;建立了专用网络数字农业技术平台。数字农业是农业现代化的高级阶段&#xff0c;是创新推动农业农村信息化发展的有效手段&#xff0c;也是我国由农业大国迈向农业强国的…...

设计模式之GoF23介绍

深入探讨设计模式&#xff1a;构建可维护、可扩展的软件架构 一、设计模式的背景1.1 什么是设计模式1.2 设计模式的历史 二、设计模式的分类2.1 创建型模式2.2 结构型模式2.3 行为型模式 三、七大设计原则四、设计模式关系结论 :rocket: :rocket: :rocket: 在软件开发领域&…...

UDP协议实现群聊

服务端 import java.io.*; import java.net.*; import java.util.ArrayList; public class Server{public static ServerSocket server_socket;public static ArrayList<Socket> socketListnew ArrayList<Socket>(); public static void main(String []args){try{…...

lombok原理 @Slf4j 怎么生成get set log

Lombok是一种Java库&#xff0c;通过注解的方式提供了许多有用的功能&#xff0c;包括生成Getter、Setter、日志等。Slf4j注解是Lombok中的一种&#xff0c;它用于自动生成日志记录器&#xff08;Logger&#xff09;。 下面简要介绍一下Lombok的原理&#xff0c;以及Slf4j注解…...

【目标检测】进行实时检测计数时,在摄像头窗口显示实时计数个数

这里我是用我本地训练的基于yolov8环境的竹签计数模型&#xff0c;在打开摄像头窗口增加了实时计数显示的代码&#xff0c;可以直接运行&#xff0c;大家可以根据此代码进行修改&#xff0c;其底层原理时将检测出来的目标的个数显示了出来。 该项目链接&#xff1a;【目标检测…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架&#xff0c;它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用&#xff0c;和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

大数据学习栈记——Neo4j的安装与使用

本文介绍图数据库Neofj的安装与使用&#xff0c;操作系统&#xff1a;Ubuntu24.04&#xff0c;Neofj版本&#xff1a;2025.04.0。 Apt安装 Neofj可以进行官网安装&#xff1a;Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...

【杂谈】-递归进化:人工智能的自我改进与监管挑战

递归进化&#xff1a;人工智能的自我改进与监管挑战 文章目录 递归进化&#xff1a;人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管&#xff1f;3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中&#xff0c;可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行&#xff0c;可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令&#xff0c;并忽略错误 rm somefile…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中&#xff0c;Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染&#xff08;即CPU被阻塞&#xff09;&#xff0c;这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案&#xff1a; 对惹&#xff0c;这里有一个游戏开发交流小组&…...

DockerHub与私有镜像仓库在容器化中的应用与管理

哈喽&#xff0c;大家好&#xff0c;我是左手python&#xff01; Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库&#xff0c;用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

【项目实战】通过多模态+LangGraph实现PPT生成助手

PPT自动生成系统 基于LangGraph的PPT自动生成系统&#xff0c;可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析&#xff1a;自动解析Markdown文档结构PPT模板分析&#xff1a;分析PPT模板的布局和风格智能布局决策&#xff1a;匹配内容与合适的PPT布局自动…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1

每日一言 生活的美好&#xff0c;总是藏在那些你咬牙坚持的日子里。 硬件&#xff1a;OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写&#xff0c;"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...