Skip to content

ConvMixer

ConvMixer model.

Implementation of ConvMixer.
ConvMixer, from the ICLR 2022 submission "Patches Are All You Need?".
Adapted from https://github.com/tmp-iclr/convmixer
Home for convmixer: https://github.com/locuslab/convmixer

The purpose of this implementation is to make the model easy to tune.
For example, you can experiment with the activation function, initialization, etc.

Import and create model

The base class for the model is ConvMixer; it returns a PyTorch Sequential model.

from model_constructor import ConvMixer

Now we can create a ConvMixer model:

convmixer_1024_20 = ConvMixer(dim=1024, depth=20)
convmixer_1024_20
output
ConvMixer(
      (0): ConvLayer(
        (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7))
        (act_fn): GELU(approximate='none')
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (3): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (4): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (5): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (6): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (7): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (8): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (9): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (10): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (11): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (12): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (13): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (14): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (15): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (16): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (17): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (18): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (19): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (20): Sequential(
        (0): Residual(
          (fn): ConvLayer(
            (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
            (act_fn): GELU(approximate='none')
            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (act_fn): GELU(approximate='none')
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (21): AdaptiveAvgPool2d(output_size=(1, 1))
      (22): Flatten(start_dim=1, end_dim=-1)
      (23): Linear(in_features=1024, out_features=1000, bias=True)
    )

Change the activation function.

Let's create a model with Mish instead of GELU (import it from torch: `from torch.nn import Mish`).

convmixer_1024_20 = ConvMixer(dim=1024, depth=20, act_fn=Mish())
convmixer_1024_20[0]
output
ConvLayer(
      (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7))
      (act_fn): Mish()
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
convmixer_1024_20[1]
output
Sequential(
      (0): Residual(
        (fn): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
          (act_fn): Mish()
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ConvLayer(
        (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (act_fn): Mish()
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

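Pre-activation

Apply the activation function before the convolution (`pre_act=True`). As the output below shows, this changes the order only inside the mixer blocks; the stem layer (`[0]`) still applies the activation after the convolution.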

convmixer_1024_20 = ConvMixer(dim=1024, depth=20, act_fn=Mish(), pre_act=True)
convmixer_1024_20[0]
output
ConvLayer(
      (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7))
      (act_fn): Mish()
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
convmixer_1024_20[1]
output
Sequential(
      (0): Residual(
        (fn): ConvLayer(
          (act_fn): Mish()
          (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ConvLayer(
        (act_fn): Mish()
        (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

Apply BatchNorm before the activation (`bn_1st=True`).

convmixer_1024_20 = ConvMixer(dim=1024, depth=20, act_fn=Mish(), bn_1st=True)
convmixer_1024_20[0]
output
ConvLayer(
      (conv): Conv2d(3, 1024, kernel_size=(7, 7), stride=(7, 7))
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): Mish()
    )
convmixer_1024_20[1]
output
Sequential(
      (0): Residual(
        (fn): ConvLayer(
          (conv): Conv2d(1024, 1024, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=1024)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_fn): Mish()
        )
      )
      (1): ConvLayer(
        (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): Mish()
      )
    )