pytorch日积月累3-计算图与动态图机制
1.计算图
计算图是用来描述运算的有向无环图。
计算图有两个主要元素:结点(Node)和边(Edge)
结点表示数据,如向量,矩阵,张量。边表示运算,如:加减乘除和卷积等。
用计算图表示:
通过分析可以知道,$y$对$w$求导就是在计算图中找到所有y到w的路径,把路径上的导数进行求和。
1 | import torch |
叶子节点:用户创建的结点称为叶子结点,如$X$与$W$;
is_leaf:指示张量是否为叶子节点;
叶子节点是整个计算图的根基,例如前面求导的计算图,在前向传导中的$a$、$b$和$y$都要依据创建的叶子节点$x$和$w$进行计算的。同样,在反向传播过程中,所有梯度的计算都要依赖叶子节点。
设置叶子节点主要是为了节省内存,在梯度反向传播结束之后,非叶子节点的梯度都会被释放掉。
1 | #查看叶子结点,通过运算得来的结点不是叶子结点 |
如果想使用非叶子结点梯度,可以使用pytorch中retain_grad()
。例如对上面代码中的$a$执行相关操作a.retain_grad()
,则$a$的梯度会被保留下来,$b$和$y$的梯度会被释放掉。
1 | a.retain_grad() |
torch.Tensor
中还有一个属性为grad_fn
,grad_fn
的作用是记录创建该张量时所用的方法(函数),该属性在梯度反向传播的时候用到。
1 | # 查看 grad_fn |
2.动态图
动态图:pytorch使用的,运算与搭建同时进行;灵活,易调节。
静态图:tensorflow使用的,先搭建图,后运算;高效,不灵活。
根据计算图搭建方式,可将计算图分为动态图和静态图。