AI
用户4984
分享
深度学习计算图中的"节点"概念
输入“/”快速插入内容
深度学习计算图中的"节点"概念
用户4984
用户4984
2024年10月25日修改
在深度学习中,节点(Node)是计算图中的基本单位,可以分为三种主要类型:
1.
叶子节点(Leaf Nodes)
◦
网络的参数(如权重 weights、偏置 bias)
◦
输入数据(如特征、标签)
◦
特点:
▪
requires_grad=True 时会记录梯度
▪
梯度存储在 .grad 属性中
▪
没有 grad_fn(因为是起始节点)
2.
中间节点(Intermediate Nodes)
◦
计算操作的结果(如两个张量的乘积)
◦
临时变量(如层与层之间的激活值)
◦
特点:
▪
有 grad_fn 属性,记录反向传播的计算函数
▪
默认不保存梯度(除非特意设置 retain_grad=True)
▪
可以访问 grad_fn 查看梯度计算方式
3.
输出节点(Output Nodes)
◦
损失函数值(loss)
◦
网络的最终输出
◦
特点:
▪
通常是反向传播的起点
▪
调用 .backward() 开始梯度计算
示例代码:
代码块
Python
import torch
# 叶子节点
w = torch.tensor([2.0], requires_grad=True) # 权重
x = torch.tensor([3.0], requires_grad=True) # 输入
# 中间节点
z = w * x # 乘法操作的结果
# 输出节点
loss = z ** 2 # 损失值
print(f"叶子节点grad_fn: {w.grad_fn}") # None
print(f"中间节点grad_fn: {z.grad_fn}") # MulBackward
print(f"输出节点grad_fn: {loss.grad_fn}") # PowBackward