I am Charmie

メモとログ

PyTorch: visualize trained filters and feature maps

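I just wrote a simple script to visualize the trained filters and feature maps of a PyTorch network. For simplicity, the code below uses a pretrained AlexNet, but the filter visualization should work with any network that has Conv2d layers (the feature-map visualization additionally assumes the network exposes its layers as `net.features`, as torchvision's AlexNet does).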

[code lang="python"]

#!/usr/bin/env python3

# -*- coding: utf-8 -*-

import os
from itertools import product

import imageio
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable

def visualize_conv_filters(net, dirname, dataname):
    """Save every Conv2d weight tensor of ``net`` as one tiled grayscale PNG.

    For each ``nn.Conv2d`` yielded by ``net.modules()`` (in traversal order),
    the 4-D weight (PyTorch layout: out_channels, in_channels/groups, kH, kW)
    is laid out as a 2-D mosaic: row block ``i`` holds output filter ``i`` and
    column block ``j`` holds its input-channel slice ``j``.  The mosaic is
    min-max scaled to 0..255 (a constant mosaic becomes all black) and written
    to ``<dirname>/<dataname>/<dataname>_conv<k>.png``, where ``k`` counts the
    Conv2d layers starting at 0.

    Args:
        net: any ``nn.Module``; only its ``nn.Conv2d`` submodules are used.
        dirname: root output directory (created, with parents, if missing).
        dataname: sub-directory name and file-name prefix.
    """
    out_dir = os.path.join(dirname, dataname)
    # makedirs (not mkdir) so a missing parent such as ./data/alexnet is fine,
    # and exist_ok so re-running the script does not fail.
    os.makedirs(out_dir, exist_ok=True)

    conv_layers = (m for m in net.modules() if isinstance(m, nn.Conv2d))
    for index_conv, module in enumerate(conv_layers):
        # .cpu() so weights living on a GPU can still be converted to NumPy.
        filters = module.weight.detach().cpu().numpy()
        n_out, n_in, kh, kw = filters.shape
        # (O, I, H, W) -> (O, H, I, W) -> (O*H, I*W): tile[i*H + h, j*W + w]
        # equals filters[i, j, h, w], the same mosaic as a nested i/j loop.
        filters_tile = filters.transpose(0, 2, 1, 3).reshape(n_out * kh,
                                                             n_in * kw)

        # Explicit uint8 conversion: PNG writers do not reliably accept raw
        # float weights, and a constant tile must not divide by zero.
        lo, hi = float(filters_tile.min()), float(filters_tile.max())
        scale = 255.0 / (hi - lo) if hi > lo else 0.0
        image = np.clip(np.rint((filters_tile - lo) * scale), 0, 255)

        filename = '%s_conv%d.png' % (dataname, index_conv)
        imageio.imwrite(os.path.join(out_dir, filename),
                        image.astype(np.uint8))

def visualize_feature_maps(net, dirname, dataname, x):
    """Feed ``x`` through ``net.features`` layer by layer and save each channel.

    ``net`` is switched to eval mode (this mutates the caller's module).  The
    layers of ``net.features`` are applied in order; after layer ``index``,
    every channel of the FIRST batch element of the output is written as a
    grayscale PNG to
    ``<dirname>/<dataname>/layer<index>/layer<index>_feature<i>.png``.
    Each map is min-max scaled to 0..255 independently; a constant map
    (e.g. all zeros after a ReLU) is written as an all-black image.

    Args:
        net: module exposing an iterable ``features`` container, as
            torchvision's AlexNet does; other networks need adapting.
        dirname: root output directory (created, with parents, if missing).
        dataname: sub-directory name for this input.
        x: input tensor.  The ``[0, i, :, :]`` indexing assumes every layer
            output is 4-D, presumably (batch, channels, height, width).
    """
    dirname_root = os.path.join(dirname, dataname)
    os.makedirs(dirname_root, exist_ok=True)

    net.eval()
    # Visualization needs no gradients; skip building an autograd graph.
    with torch.no_grad():
        for index, layer in enumerate(net.features):
            layer_dir = os.path.join(dirname_root, 'layer' + str(index))
            os.makedirs(layer_dir, exist_ok=True)

            x = layer(x)
            # .cpu() so activations computed on a GPU can become NumPy arrays.
            feature = x.detach().cpu().numpy()
            for i in range(feature.shape[1]):
                fmap = feature[0, i, :, :]
                # Explicit uint8 conversion; guard the constant-map case.
                lo, hi = float(fmap.min()), float(fmap.max())
                scale = 255.0 / (hi - lo) if hi > lo else 0.0
                image = np.clip(np.rint((fmap - lo) * scale), 0, 255)

                filename = 'layer%d_feature%d.png' % (index, i)
                imageio.imwrite(os.path.join(layer_dir, filename),
                                image.astype(np.uint8))

if __name__ == '__main__':
    # Prefer the torchvision >= 0.13 ``weights=`` API; fall back to the older
    # (now deprecated) ``pretrained=True`` flag when it is not available.
    alexnet_weights = getattr(models, 'AlexNet_Weights', None)
    if alexnet_weights is not None:
        net = models.alexnet(weights=alexnet_weights.DEFAULT)
    else:
        net = models.alexnet(pretrained=True)

    visualize_conv_filters(net, './data/alexnet', 'filters')

    # A plain tensor is sufficient; torch.autograd.Variable is deprecated.
    x = torch.randn(1, 3, 224, 224)
    visualize_feature_maps(net, './data/alexnet', 'random', x)

[/code]