Module warm.module
Custom modules to enhance the nn.Sequential experience.
PyWarm's core concept is to use a functional interface to simplify network building.
However, if you still prefer the classical way of defining child modules in __init__(), PyWarm provides some utilities to help organize child modules better.
- Lambda can be used to wrap one-line data transformations, like x.view(), x.permute(), etc., into modules.
- Sequential is an extension to nn.Sequential that better accommodates PyTorch RNNs.
- Shortcut is another extension to nn.Sequential that also performs a shortcut addition (AKA residual connection) of the input to the output, so that residual blocks can be written in an entirely sequential way.
For example, to define the basic block type for ResNet:
    import torch.nn as nn
    import warm.module as wm

    def basic_block(size_in, size_out, stride=1):
        block = wm.Shortcut(
            nn.Conv2d(size_in, size_out, 3, stride, 1, bias=False),
            nn.BatchNorm2d(size_out),
            nn.ReLU(),
            nn.Conv2d(size_out, size_out, 3, 1, 1, bias=False),
            nn.BatchNorm2d(size_out),
            projection=wm.Lambda(
                lambda x: x if x.shape[1] == size_out else nn.Sequential(
                    nn.Conv2d(size_in, size_out, 1, stride, bias=False),
                    nn.BatchNorm2d(size_out),
                )(x),
            ),
        )
        return block
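To see why the projection is needed, the following plain-PyTorch sketch (warm is not required) reuses the layer arguments from basic_block. The values size_in=64, size_out=128, stride=2 and the 56x56 input are illustrative assumptions. The main path halves the spatial size and doubles the channels, so the raw input can no longer be added to it, while the 1x1 projection produces a matching shape.

```python
import torch
import torch.nn as nn

# Illustrative values (assumptions, not taken from PyWarm).
size_in, size_out, stride = 64, 128, 2
x = torch.randn(1, size_in, 56, 56)

# Main path: the same layers passed to wm.Shortcut in basic_block.
body = nn.Sequential(
    nn.Conv2d(size_in, size_out, 3, stride, 1, bias=False),
    nn.BatchNorm2d(size_out),
    nn.ReLU(),
    nn.Conv2d(size_out, size_out, 3, 1, 1, bias=False),
    nn.BatchNorm2d(size_out),
)

# 1x1 projection: changes the channel count and applies the same stride.
projection = nn.Sequential(
    nn.Conv2d(size_in, size_out, 1, stride, bias=False),
    nn.BatchNorm2d(size_out),
)

print(body(x).shape)        # torch.Size([1, 128, 28, 28])
print(projection(x).shape)  # torch.Size([1, 128, 28, 28])

# x itself has shape (1, 64, 56, 56) and could not be added directly.
y = body(x) + projection(x)
```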
Classes
Lambda
class Lambda(fn, *arg, **kw)
Wraps a callable and all its call arguments.
fn: callable; The callable being wrapped.
*arg: list; Arguments to be passed to fn.
**kw: dict; Keyword arguments to be passed to fn.
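The following is a minimal sketch of the behavior described above, not PyWarm's actual source. LambdaSketch is a hypothetical name, and the sketch assumes the wrapped callable is invoked as fn(x, *arg, **kw). It shows how one-line reshapes such as x.permute() can sit between layers of an ordinary nn.Sequential.

```python
import torch
import torch.nn as nn


class LambdaSketch(nn.Module):
    """Hypothetical stand-in for wm.Lambda.

    Assumption: forward(x) calls fn(x, *arg, **kw).
    """

    def __init__(self, fn, *arg, **kw):
        super().__init__()
        self.fn = fn    # the wrapped callable
        self.arg = arg  # extra positional arguments, stored once
        self.kw = kw    # extra keyword arguments, stored once

    def forward(self, x):
        return self.fn(x, *self.arg, **self.kw)


model = nn.Sequential(
    nn.Conv1d(3, 8, kernel_size=1),               # (N, 3, L) -> (N, 8, L)
    LambdaSketch(torch.Tensor.permute, 0, 2, 1),  # (N, 8, L) -> (N, L, 8)
    nn.Linear(8, 4),                              # acts on the last dim: (N, L, 4)
    LambdaSketch(torch.flatten, 1),               # (N, L, 4) -> (N, L * 4)
)
out = model(torch.randn(2, 3, 5))
print(out.shape)  # torch.Size([2, 20])
```

Storing the arguments at construction time is what lets a transformation that needs parameters (such as the permutation order) be expressed as a module with a single-input forward.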
Ancestors (in MRO)
- torch.nn.modules.module.Module
Methods
forward
def forward(self, x)
Sequential
class Sequential(*args)
Similar to nn.Sequential, except that child modules can have multiple outputs (e.g. nn.RNN).
*arg: list of Modules; Same as nn.Sequential.
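PyTorch RNN modules return a tuple (for nn.LSTM, output and (h_n, c_n)), so a plain nn.Sequential would hand that whole tuple to the next layer. The sketch below is a hypothetical illustration, not PyWarm's source: SequentialSketch is an invented name, and it assumes that when a child returns a tuple only the first element is forwarded.

```python
import torch
import torch.nn as nn


class SequentialSketch(nn.Sequential):
    """Hypothetical stand-in for wm.Sequential.

    Assumption: if a child returns a tuple, only its first element
    is passed on to the next child.
    """

    def forward(self, x):
        for module in self:
            x = module(x)
            if isinstance(x, tuple):
                x = x[0]  # keep the RNN output, drop the hidden state(s)
        return x


model = SequentialSketch(
    nn.LSTM(input_size=4, hidden_size=6, batch_first=True),  # returns (output, (h_n, c_n))
    nn.Linear(6, 2),
)
y = model(torch.randn(3, 7, 4))  # (batch, seq, features)
print(y.shape)  # torch.Size([3, 7, 2])
```

With a plain nn.Sequential, the same two layers fail because nn.Linear receives a tuple instead of a tensor.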
Ancestors (in MRO)
- torch.nn.modules.container.Sequential
- torch.nn.modules.module.Module
Descendants
- warm.module.Shortcut
Methods
forward
def forward(self, x)
Shortcut
class Shortcut(*arg, projection=None)
Similar to nn.Sequential, except that it performs a shortcut addition of the input to the output.
*arg: list of Modules; Same as nn.Sequential.
projection: None or callable; If None, the input will be added directly to the output. Otherwise, the input will be passed to projection first, usually to make the shapes match.
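A minimal sketch of the computation described above: the output is body(x) + projection(x), or body(x) + x when projection is None. ShortcutSketch is a hypothetical name and this is not PyWarm's source. In particular, it keeps the layers in a separate nn.Sequential attribute rather than subclassing Sequential, so that a projection module is registered (and trained) without being run as part of the main path. This is a design choice made for the sketch, not a claim about how the real class is built.

```python
import torch
import torch.nn as nn


class ShortcutSketch(nn.Module):
    """Hypothetical stand-in for wm.Shortcut: y = body(x) + shortcut(x)."""

    def __init__(self, *layers, projection=None):
        super().__init__()
        self.body = nn.Sequential(*layers)  # main path
        self.projection = projection        # None, a module, or any callable

    def forward(self, x):
        shortcut = x if self.projection is None else self.projection(x)
        return self.body(x) + shortcut


# Identity shortcut: input and output shapes already match.
same = ShortcutSketch(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Projected shortcut: a bias-free linear map widens the input to 16 features.
proj = ShortcutSketch(nn.Linear(8, 16), nn.ReLU(),
                      projection=nn.Linear(8, 16, bias=False))

x = torch.randn(4, 8)
print(same(x).shape, proj(x).shape)  # torch.Size([4, 8]) torch.Size([4, 16])
```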
Ancestors (in MRO)
- warm.module.Sequential
- torch.nn.modules.container.Sequential
- torch.nn.modules.module.Module