Vision: Number of input channels of DenseNets

Created on 13 Oct 2019 · 2 Comments · Source: pytorch/vision

Hi all, I'm working with data that has only 2 channels instead of the usual 3 RGB channels.

By default, DenseNets in torchvision expect 3 input channels. What is the best way to change the number of input channels, or could we add a new argument such as num_input_channels?

Labels: models, question, classification

Most helpful comment

Since all our models were trained on RGB images, I'm almost certain you won't get meaningful results with the pretrained weights, so you need to train from scratch.

For example, you can swap the input layer like this:

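from torch import nn
from torchvision.models import densenet121

num_input_channels = 2

# Untrained DenseNet-121; the pretrained RGB weights would not help here anyway.
old_model = densenet121()
old_features = old_model.features
old_input_layer = old_features[0]  # conv0: Conv2d(3, 64, kernel_size=7, ...)

# Copy every hyperparameter of the original first convolution except in_channels.
conv_args = [getattr(old_input_layer, attr) for attr in ("out_channels", "kernel_size")]
conv_kwargs = {
    attr: getattr(old_input_layer, attr)
    for attr in ("stride", "padding", "dilation", "groups", "padding_mode")
}
# The .bias attribute is a Parameter or None, while nn.Conv2d expects a bool.
conv_kwargs["bias"] = old_input_layer.bias is not None
new_input_layer = nn.Conv2d(num_input_channels, *conv_args, **conv_kwargs)

# Rebuild the feature extractor with the new first layer (modifies old_model in place).
new_features = nn.Sequential(new_input_layer, *old_features[1:])
new_model = old_model
new_model.features = new_features

print(new_model)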
DenseNet(
  (features): Sequential(
    (0): Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): _DenseBlock(
...

It works. Thanks a lot!