#04 构建您的第一个神经网络:PyTorch入门指南

文章目录

  • 前言
    • 理论基础
      • 神经网络层的组成
      • 前向传播与反向传播
    • 神经网络设计
      • 步骤1:准备数据集
      • 步骤2:构建模型
      • 步骤3:定义损失函数和优化器
      • 步骤4:训练模型
      • 步骤5:评估模型
      • 结论


前言

  在过去的几天里,我们深入了解了深度学习的基础,从安装PyTorch开始,到理解Tensor运算,以及自动微分系统Autograd的工作原理。今天,我们将整合我们所学的知识和技能,并迈出一个重要的步伐:构建和训练我们的第一个神经网络。
在这里插入图片描述

  在本篇博文中,我会指导您如何设计一个简单的全连接神经网络(也称为多层感知机MLP),并使用PyTorch作为您的工具。我们的目标是用这个网络解决一个分类问题,并了解神经网络训练的基本流程。让我们一步一步地进行。

理论基础

  在开始编码之前,让我们先回顾一下全连接神经网络的基本组件和原理。一个典型的神经网络由多个层组成,每个层由多个神经元组成。在全连接网络中,一个层中的每个神经元都与前一层的所有神经元连接。这种密集的连接模式能让网络从输入数据中捕获复杂的模式。

神经网络层的组成

神经网络中的每一层主要包含以下部分:

  1. 输入节点:这些是将数据输入到网络的节点。
  2. 权重:每个连接都有一个权重,它决定了输入信号对神经元的激活程度的影响。
  3. 偏置:每个神经元都有一个偏置,它提供了除了输入以外的另一个调节激活的手段。
  4. 激活函数:它决定了神经元的输出,通常是一个非线性函数,如ReLU或Sigmoid。

前向传播与反向传播

神经网络的训练分为两个阶段:前向传播和反向传播。

  • 前向传播:在这个阶段,输入数据通过网络的每一层,每个神经元根据其权重、偏置和激活函数计算输出。
  • 反向传播:训练的这一部分涉及根据网络输出和实际结果之间的差异来调整权重和偏置。这是通过计算损失函数的梯度并将其反向传播到网络中完成的。

神经网络设计

现在,让我们设计一个简单的神经网络。假设我们要解决的是一个二分类问题,我们的网络将有两个输入节点(对应于两个特征),几个隐藏层,以及一个输出节点。

步骤1:准备数据集

在构建模型之前,我们首先需要准备数据集。通常,我们会对数据进行预处理,如标准化,然后分为训练集和测试集。

import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# 假设data为特征,labels为标签
data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.2)
train_dataset = TensorDataset(torch.tensor(data_train, dtype=torch.float32), torch.tensor(labels_train, dtype=torch.float32))
test_dataset = TensorDataset(torch.tensor(data_test, dtype=torch.float32), torch.tensor(labels_test, dtype=torch.float32))

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

步骤2:构建模型

  接下来,我们将使用PyTorch定义我们的神经网络。我们将使用nn.Module作为基类,并定义我们的层和前向传播方法。

import torch.nn as nn
import torch.nn.functional as F

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(2, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

model = SimpleNN()

步骤3:定义损失函数和优化器

  损失函数用于评估模型的预测与实际标签之间的差异。优化器用于根据损失函数的梯度更新模型的权重。

import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

步骤4:训练模型

训练模型涉及多个迭代(或“epoch”),在每个迭代中,我们将完成一个完整的前向和后向传播过程。

for epoch in range(100):  # 举例迭代100次
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()

步骤5:评估模型

最后,我们需要评估我们的模型性能,通常使用测试集来完成这个任务。

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        predicted = outputs.round().squeeze()  # 将输出四舍五入为0或1
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

结论

  在以上步骤中,我们从头开始构建了一个简单的神经网络,并在PyTorch中进行了训练和评估。这个过程涵盖了许多重要的概念和技巧,是进一步学习深度学习的坚实基础。

  值得注意的是,真实世界的神经网络要复杂得多,涉及到更复杂的数据预处理、更深的网络结构、正则化技术,以及更精细的训练技巧。但通过这个基本的例子,我们已经开始了深度学习之旅,并为未来的学习打下了坚实的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/612923.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Java】获取近六个月的年月

数据库里面存储的字段类型就是varchar&#xff0c;数据格式就是类似2024-12这样的年月格式。 目标&#xff1a; 以当前月份为标准&#xff0c;向前获取近6个月的年月&#xff08;year_month&#xff09;形成列表 // 获取近6个月的年月列表List<String> recentMonths ge…

fswatch工具:跟踪Linux中的文件和目录更改

fswatch是一个跨平台的文件更改监视器&#xff0c;当指定文件或目录的内容被更改或修改时&#xff0c;它会收到通知警报。 fswatch在不同的操作系统上执行多种类型的监视器&#xff0c;例如&#xff1a; 基于 Apple OS X 的文件系统事件 API 构建的监视器。基于kqueue的监视器…

Nginx部署前后端分离项目

部署前后端分离项目&#xff0c;要求前端项目、后端项目、数据库分别部署在3台服务器 服务器准备 服务器名IP软件包前端192.168.99.137nginx后端192.168.99.139jar数据库192.168.99.100mariadb 1、前端服务器 yum install -y epel-release && yum install -y nginx…

9.spring-图书管理系统

文章目录 1.开发项目流程1.1开发开发1.2数据库的设计 2.MySQL数据库相关代码3.构造图书结构3.1用户登录3.2图书列表3.3图书添加3.4图书删除3.4.1批量删除 3.5图书查询(翻页) 4.页面展示4.1登录页面4.2列表页面4.3增加图书页面4.4修改图书信息页面 5.功能展示5.1增加图书信息5.2…

分布式与一致性协议之TCC协议

TCC协议 概述 虽然MySQL XA能实现数据层的分布式事务&#xff0c;解决多个MySQL操作的事务问题&#xff0c;但还面临别的问题:在接收到外部的指令后&#xff0c;需要访问多个内部系统&#xff0c;执行指令约定的操作&#xff0c;还必须保证指令执行的原子性(也就是事务要么全…

nestjs 全栈进阶--中间件

视频教程 22_nest中中间件_哔哩哔哩_bilibili 1. 介绍 在Nest.js框架中&#xff0c;中间件&#xff08;Middleware&#xff09;是一个非常重要的概念&#xff0c;它是HTTP请求和响应生命周期中的一个重要组成部分&#xff0c;允许开发者在请求到达最终的目的控制器方法之前或…

机器学习算法 - 逻辑回归

逻辑回归是一种广泛应用于统计学和机器学习领域的回归分析方法&#xff0c;主要用于处理二分类问题。它的目的是找到一个最佳拟合模型来预测一个事件的发生概率。以下是逻辑回归的一些核心要点&#xff1a; 基本概念 输出&#xff1a;逻辑回归模型的输出是一个介于0和1之间的…

【数据结构】队列详解(Queue)

文章目录 有关队列的概念队列的结点设计及初始化队列的销毁判空和计数入队操作出队操作 有关队列的概念 队列:只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c;队列具有先进先出FIFO(First In First Out)入队列:进行插入操作的一端…

Redis不同数据类型value存储

一、Strings redis中String的底层没有用c的char来实现&#xff0c;而是使用SDS数据结构( char buf[])。 缺点:浪费空间 优势: 1.c字符串不记录自身的长度&#xff0c;所以获取一个字符串长度的复杂度是O(N),但是SDS记录分配的长度alloc,已使用长度len&#xff0c;获取长度的…

CSS文字描边,文字间隔,div自定义形状切割

clip-path: polygon( 0 0, 68% 0, 100% 32%, 100% 100%, 0 100% );//这里切割出来是少一角的正方形 letter-spacing: 1vw; //文字间隔 -webkit-text-stroke: 1px #fff; //文字描边1px uniapp微信小程序顶部导航栏设置透明&#xff0c;下拉改变透明度 onP…

[js] 递归,数组对象根据某个值进行升序或者降序

一、效果图 1.1 父级 1.2 父级与子级 二、代码 升序降序&#xff0c;只要把 a.num - b.num 改成 b.num - a.num <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, i…

pycharm 将项目连同库一起打包及虚拟环境的使用

目录 一、创建虚拟环境 1、用 anaconda 创建 2、Pycharm 直接创建 二、虚拟环境安装第三方库 1、创建项目后&#xff0c;启动终端(Alt F12)&#xff0c;或者点击下方标记处。 2、使用 pip 或者 conda 来进行三方库的安装或卸载 3、将项目中的库放入文档&#xff0c;便于…

Dual Aggregation Transformer for Image Super-Resolution论文总结

题目&#xff1a;Dual Aggregation Transformer&#xff08;双聚合Transformer&#xff09; for Image Super-Resolution&#xff08;图像超分辨&#xff09; 论文&#xff08;ICCV&#xff09;&#xff1a;Chen_Dual_Aggregation_Transformer_for_Image_Super-Resolution_ICCV…

Android之给Button上添加按压效果

一、配置stateListAnimator参数实现按压效果 1、按钮控件 <Buttonandroid:id"id/mBtnLogin"android:layout_width"match_parent"android:layout_height"48dp"android:background"drawable/shape_jfrb_login_button"android:state…

腾讯共享WiFi项目的加盟方式有哪些?

在这个互联互通的时代&#xff0c;共享经济的浪潮正以前所未有的力量席卷全球&#xff0c;而腾讯作为中国互联网巨头之一自然不会错过这场盛宴。其推出的腾讯共享WiFi项目自问世以来就备受瞩目&#xff0c;它不仅为用户提供便捷的上网服务&#xff0c;更为创业者打开了一个全新…

Language2Pose: Natural Language Grounded Pose Forecasting # 论文阅读

URL https://arxiv.org/pdf/1907.01108 TD;DR 19 年 7 月 cmu 的文章&#xff0c;提出一种基于 natural language 生成 3D 动作序列的方法。通过一个简单的 CNN 模型应该就可以实现 Model & Method 首先定义一下任务&#xff1a; 输入&#xff1a;用户的自然语言&…

win10电脑桌面便签纸怎么设置?添加桌面便签方法

对于上班族来说&#xff0c;电脑桌面上的电子便签纸是一项不可或缺的工具。在快节奏的工作环境中&#xff0c;我们经常需要随时记录重要信息、安排工作任务&#xff0c;而电子便签纸以其便捷性和实时性成为了我们的得力助手。 想象一下&#xff0c;在紧张的项目讨论中&#xf…

mysql 细分

索引选择性 索引列的唯一值数量 / 表中的总行数 mysql如何优化-CSDN博客 批量问题 批处理默认是逐条发送 SQL 到数据库的&#xff0c;没有充分利用数据库提供的原生批处理能力&#xff0c;需要额外的配置来启用真正的批处理支持&#xff0c;如使用ExecutorType.BATCH 自定…

提升网络性能,解决网络故障,了解AnaTraf网络流量分析仪

在当今数字化时代&#xff0c;网络性能监测与诊断(Network Performance Monitoring and Diagnosis,NPMD)成为了企业和个人关注的焦点。随着网络流量不断增长&#xff0c;确保网络的稳定性和高效性变得更加重要。在这个领域&#xff0c;AnaTraf网络流量分析仪是您不可或缺的得力…

自然资源-土地征收成片开发知识梳理

自然资源-土地征收成片开发知识梳理 1、什么是成片开发 &#xff1f; 自然资源部印发的《土地征收成片开发标准&#xff08;试行&#xff09;》对成片开发的概念做了界定&#xff0c;成片开发是指在国土空间规划确定的城镇开发边界内的集中建设区&#xff0c;由县级以上地方人…