ConvMixer
ConvMixer model.
Implementation of ConvMixer.
ConvMixer - ICLR 2022 submission "Patches Are All You Need?".
Adapted from https://github.com/tmp-iclr/convmixer
Home for convmixer: https://github.com/locuslab/convmixer
The purpose of this implementation is to make the model easy to tune.
For example, you can experiment with the activation function, initialization, etc.
Import and create model
The base class for the model is ConvMixer; it returns a PyTorch Sequential model.
from model_constructor import ConvMixer
Now we can create a ConvMixer model:
convmixer_1024_20 = ConvMixer(dim=1024, depth=20)
convmixer_1024_20
output
ConvMixer( (0): ConvLayer( (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (2): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (3): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (4): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') 
(bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (5): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (6): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (7): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (8): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (9): Sequential( (0): 
Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (10): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (11): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (12): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (13): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) 
(act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (14): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (15): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (16): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (17): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (18): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (19): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (20): Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): GELU(approximate='none') (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (21): AdaptiveAvgPool2d(output_size=(1, 1)) (22): Flatten(start_dim=1, end_dim=-1) (23): Linear(in_features=1024, out_features=1000, bias=True) )
Change activation function.
Let's create a model with Mish (import it from torch, e.g. `from torch.nn import Mish`) instead of GELU.
convmixer_1024_20 = ConvMixer(dim=1024, depth=20, act_fn=Mish())
convmixer_1024_20[0]
output
ConvLayer( (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7)) (act_fn): Mish() (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )
convmixer_1024_20[1]
output
Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (act_fn): Mish() (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (act_fn): Mish() (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) )
Pre activation
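Apply the activation function before the convolution. As the output below shows, the first layer (index 0) keeps the activation after the convolution; only the layers inside the mixer blocks are reordered.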
convmixer_1024_20 = ConvMixer(dim=1024, depth=20, act_fn=Mish(), pre_act=True)
convmixer_1024_20[0]
output
ConvLayer( (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7)) (act_fn): Mish() (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )
convmixer_1024_20[1]
output
Sequential( (0): Residual( (fn): ConvLayer( (act_fn): Mish() (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): ConvLayer( (act_fn): Mish() (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) )
BatchNorm before activation.
convmixer_1024_20 = ConvMixer(dim=1024, depth=20, act_fn=Mish(), bn_1st=True)
convmixer_1024_20[0]
output
ConvLayer( (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7)) (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act_fn): Mish() )
convmixer_1024_20[1]
output
Sequential( (0): Residual( (fn): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024) (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act_fn): Mish() ) ) (1): ConvLayer( (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)) (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act_fn): Mish() ) )