Detecting COVID-19 on X-Ray Images with a CNN

Computer vision has become a powerful tool in healthcare for detecting and diagnosing medical conditions. Models trained to recognize patterns in medical images make it possible to screen for diseases at scale, and one prominent application is detecting disease on x-ray images.

In this note, we use a Convolutional Neural Network (CNN) to build a classification model that predicts COVID-19 cases from chest x-ray images. The data comes from the COVID-19 chest x-ray dataset available at https://github.com/ieee8023/covid-chestxray-dataset.

COVID DATASET

The dataset used to train and evaluate the model contains 284 x-ray images, split as follows:

  • Train: 112 Covid, 112 Normal
  • Validation: 30 Covid, 30 Normal

For more details on the dataset, see the GitHub repository linked above.

1. Dataset Preparation and Visualization

The first step is to download, process, and visualize the training data. To load the images into a PyTorch dataset we use torchvision.datasets.ImageFolder(), together with torchvision.transforms for resizing, cropping, and converting the images to tensors.

Note that datasets.ImageFolder() expects a folder structure in which each subfolder represents one class and every image in that subfolder belongs to that class. For example, all COVID-positive training images go in the Covid subfolder of Train, and all normal images go in the Normal subfolder. ImageFolder assigns integer labels by sorting the subfolder names alphabetically, so here Covid maps to 0 and Normal maps to 1.

Here is an example of how to structure your image data.

! tree data -d
data
├── Train
│   ├── Covid
│   └── Normal
└── Val
    ├── Covid
    └── Normal
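ImageFolder turns this layout into labels by sorting the class subfolder names and numbering them from 0. The stdlib-only sketch below mirrors that discovery step so the mapping can be checked without loading any images. The function names are hypothetical, the extension filter is a simplified assumption (torchvision accepts more formats), and unlike torchvision this sketch does not descend into nested subfolders. On a real dataset the same mapping is available as train_dataset.class_to_idx.

```python
import os
import tempfile

# Simplified extension filter (assumption: torchvision accepts more formats)
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")

def find_classes_like_imagefolder(root):
    """Hypothetical helper: sorted subfolder names -> integer labels."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}

def list_samples(root):
    """Hypothetical helper: (path, label) pairs, walking classes in label order."""
    _, class_to_idx = find_classes_like_imagefolder(root)
    samples = []
    for name, idx in class_to_idx.items():
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMG_EXTENSIONS):
                samples.append((os.path.join(class_dir, fname), idx))
    return samples

# Demo on a throwaway copy of the Train layout with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    for name, n_files in [("Normal", 2), ("Covid", 3)]:
        os.makedirs(os.path.join(root, name))
        for i in range(n_files):
            open(os.path.join(root, name, f"img_{i}.png"), "w").close()
    demo_classes, demo_class_to_idx = find_classes_like_imagefolder(root)
    demo_labels = [label for _, label in list_samples(root)]

print(demo_classes)        # ['Covid', 'Normal']
print(demo_class_to_idx)   # {'Covid': 0, 'Normal': 1}
print(demo_labels)         # [0, 0, 0, 1, 1]
```

The creation order of the folders does not matter: sorting is what makes Covid label 0, which is worth remembering when interpreting the model's single output later on.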

1.1. Importing Images into a PyTorch Dataset with datasets.ImageFolder()

torchvision's ImageFolder is a utility class that turns a directory of raw images into a dataset object that PyTorch utilities such as DataLoader can consume. Its main arguments are root, the path to the top-level folder, and transform, the preprocessing applied to each image.

Defining transforms is essential here because the images must all be resized to a common shape and converted to tensors before they can be used for training.

import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import transforms, datasets

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

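For the transform, I chain three steps:

  • 1. Resize each image so that its shorter side is 224 pixels (the aspect ratio is preserved).
  • 2. Crop the central 224 x 224 region.
  • 3. Convert the image to a tensor with values scaled to [0, 1].

ImageFolder's default loader opens every image in RGB mode, so each resulting tensor has shape 3 x 224 x 224.

transform = transforms.Compose([
    transforms.Resize(224),       # shorter side -> 224 px, aspect ratio kept
    transforms.CenterCrop(224),   # keep the central 224 x 224 region
    transforms.ToTensor(),        # PIL image -> float tensor (C, H, W) in [0, 1]
])
train_dataset = datasets.ImageFolder('data/Train/', transform=transform)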
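Because Resize(224) receives a single integer, it scales the shorter side to 224 and keeps the aspect ratio, so it is the CenterCrop that makes every image exactly 224 x 224. The plain-Python sketch below traces the output size for one input. Both helpers are hypothetical, not torchvision functions; the truncation of the longer side to an integer and the rounding of the crop offsets follow torchvision's behavior as I understand it, which is worth re-checking against your installed version.

```python
def resize_shorter_side(height, width, size):
    """Hypothetical helper: output (height, width) of Resize(size) for an int size."""
    if height <= width:
        return size, int(size * width / height)
    return int(size * height / width), size

def center_crop_box(height, width, crop):
    """Hypothetical helper: (top, left, bottom, right) of the central crop x crop window."""
    top = int(round((height - crop) / 2.0))
    left = int(round((width - crop) / 2.0))
    return top, left, top + crop, left + crop

# A made-up 1024 x 1280 (height x width) landscape x-ray
h, w = resize_shorter_side(1024, 1280, 224)
print((h, w))                       # (224, 280)
print(center_crop_box(h, w, 224))   # (0, 28, 224, 252)
```

For this input the crop discards 28 columns on each side, so wide images lose some of their left and right edges rather than being squashed.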

train_dataset is now a PyTorch dataset, so individual samples can be retrieved by index. Each sample is an (image_tensor, label) tuple:

train_dataset[-1][0]
tensor([[[0.8510, 0.8510, 0.8392,  ..., 0.3451, 0.3294, 0.3216],
    [0.6863, 0.6784, 0.6706,  ..., 0.2588, 0.2431, 0.2353],
    [0.5451, 0.5373, 0.5255,  ..., 0.2235, 0.2118, 0.2000],
    ...,
    [0.0078, 0.0471, 0.0745,  ..., 0.4039, 0.4118, 0.4157],
    [0.0039, 0.0471, 0.0745,  ..., 0.4235, 0.4275, 0.4314],
    [0.0039, 0.0392, 0.0706,  ..., 0.4275, 0.4353, 0.4392]],

   [[0.8510, 0.8510, 0.8392,  ..., 0.3451, 0.3294, 0.3216],
    [0.6863, 0.6784, 0.6706,  ..., 0.2588, 0.2431, 0.2353],
    [0.5451, 0.5373, 0.5255,  ..., 0.2235, 0.2118, 0.2000],
    ...,
    [0.0078, 0.0471, 0.0745,  ..., 0.4039, 0.4118, 0.4157],
    [0.0039, 0.0471, 0.0745,  ..., 0.4235, 0.4275, 0.4314],
    [0.0039, 0.0392, 0.0706,  ..., 0.4275, 0.4353, 0.4392]],

   [[0.8510, 0.8510, 0.8392,  ..., 0.3451, 0.3294, 0.3216],
    [0.6863, 0.6784, 0.6706,  ..., 0.2588, 0.2431, 0.2353],
    [0.5451, 0.5373, 0.5255,  ..., 0.2235, 0.2118, 0.2000],
    ...,
    [0.0078, 0.0471, 0.0745,  ..., 0.4039, 0.4118, 0.4157],
    [0.0039, 0.0471, 0.0745,  ..., 0.4235, 0.4275, 0.4314],
    [0.0039, 0.0392, 0.0706,  ..., 0.4275, 0.4353, 0.4392]]])
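Two details in this output are worth decoding. ToTensor rearranges an 8-bit image from height x width x channels to channels x height x width and divides by 255, so 0.8510 corresponds to the pixel value 217. The three channels are identical, most likely because the x-ray is a grayscale image that the loader converted to RGB, which copies the single channel into all three. The NumPy sketch below imitates both steps on a tiny made-up image; to_tensor_like is a hypothetical stand-in for transforms.ToTensor, not torchvision code.

```python
import numpy as np

def to_tensor_like(image_hwc_uint8):
    """Hypothetical stand-in for ToTensor on uint8 input: HWC -> CHW, scaled to [0, 1]."""
    return np.transpose(image_hwc_uint8, (2, 0, 1)).astype(np.float32) / 255.0

# A made-up 2 x 3 grayscale image
gray = np.array([[217, 175, 139],
                 [1, 12, 112]], dtype=np.uint8)

# Grayscale -> RGB by repeating the single channel three times
rgb = np.repeat(gray[:, :, None], 3, axis=2)   # shape (2, 3, 3)

tensor_like = to_tensor_like(rgb)              # shape (3, 2, 3)
print(tensor_like.shape)                                # (3, 2, 3)
print(round(float(tensor_like[0, 0, 0]), 4))            # 0.851
print(np.array_equal(tensor_like[0], tensor_like[1]))   # True
```

Since the channels carry no extra information, the first convolution effectively sees the same image three times; the model still works, it just spends a few redundant weights.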

1.2. Visualizing X-Ray Images

To visualize images, I create a simple helper function that can display either a single image or a whole batch of images.

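import math

def viewXRay(img_input, batch=False):
    """Plot one (C, H, W) image tensor, or a batch of them as a grid if batch=True."""
    if batch:
        # Up to 8 images per row, with as many rows as the batch needs
        n_images = len(img_input)
        n_cols = min(8, n_images)
        n_rows = math.ceil(n_images / n_cols)
        fig = plt.figure(figsize=(25, 20))
        fig.subplots_adjust(bottom=0.025, left=0.025, top=0.975, right=0.975)
        for i in range(n_images):
            fig.add_subplot(n_rows, n_cols, i + 1)
            # PyTorch stores images as (C, H, W); matplotlib expects (H, W, C)
            plt.imshow(img_input[i].permute(1, 2, 0))
        plt.tight_layout()

    else:
        plt.figure(figsize=(8, 4))
        plt.imshow(img_input.permute(1, 2, 0))
        plt.tight_layout()

    plt.show()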
viewXRay(train_dataset[-1][0])
[Figure: x-ray image train_dataset[-1]]
viewXRay(train_dataset[5][0])
[Figure: x-ray image train_dataset[5]]

1.3. Visualization for Batches

The helper function also handles batches: we can pass a batch from the DataLoader directly to it, and it plots the x-ray images as a grid of subplots.

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

images, labels = next(iter(train_dataloader))
viewXRay(images, batch=True)
[Figure: a batch of training x-rays shown as a grid]

2. Building the CNN Model

To the trained eye, these x-ray images may be easy to read. For the rest of us, a deep learning model that learns image features on its own may be the best way to tell COVID x-rays apart from normal ones. The architecture I use has 3 convolutional layers followed by 2 linear layers. Max pooling follows the second and third convolutional layers, and dropout layers provide regularization. An adaptive average pooling layer then reduces each of the 128 feature maps to 2 x 2 before the linear layers.

The implementation of the architecture is shown below:

import numpy as np
import torch.nn as nn
import torch.optim as optim

class CovidXRayDetection(nn.Module):

    def __init__(self, n_classes=2):
        # n_classes is not used: the binary task needs only a single output logit
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1),   # 3x224x224 -> 32x222x222
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),            # -> 64x220x220
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),                                           # -> 64x110x110
            nn.Dropout(0.25),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),           # -> 128x108x108
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),                                           # -> 128x54x54
            nn.Dropout(0.25),
        )

        # Average each feature map down to 2x2, whatever the input resolution
        self.avg_pool = nn.AdaptiveAvgPool2d(2)                                   # -> 128x2x2
        self.classifier = nn.Sequential(
            nn.Linear(128 * 2 * 2, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(64, 1),   # one raw logit per image; no sigmoid here
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)   # (N, 128, 2, 2) -> (N, 512)
        return self.classifier(x)
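The 128 * 2 * 2 input size of the first linear layer follows from the standard output-size formula for convolution and pooling layers, out = floor((in + 2 * padding - kernel) / stride) + 1. The plain-Python sketch below traces a 224 x 224 input through the feature extractor. conv_out is a hypothetical helper, and the layer list is transcribed by hand from CovidXRayDetection rather than read from the module.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Hypothetical helper: spatial output size of Conv2d / MaxPool2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride) for each size-changing layer in self.features;
# MaxPool2d(kernel_size=2) defaults to stride = kernel_size = 2
layers = [("conv1", 3, 1), ("conv2", 3, 1), ("pool1", 2, 2),
          ("conv3", 3, 1), ("pool2", 2, 2)]

size = 224
trace = []
for name, kernel, stride in layers:
    size = conv_out(size, kernel, stride)
    trace.append((name, size))

print(trace)
# [('conv1', 222), ('conv2', 220), ('pool1', 110), ('conv3', 108), ('pool2', 54)]

# AdaptiveAvgPool2d(2) then fixes each of the 128 maps at 2 x 2, whatever `size` is
print(128 * 2 * 2)   # 512
```

Without the adaptive pooling layer, the classifier would need 128 * 54 * 54 inputs and would break for any other image size; pooling to a fixed 2 x 2 keeps the linear layer small and size-independent.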

3. Model Initialization and Training

To train the model, I use the Adam optimizer with a learning rate of 0.01 and a binary cross-entropy loss, since the target is binary (Covid or Normal). Specifically, I use BCEWithLogitsLoss: the model outputs raw logits rather than probabilities, and the loss applies the sigmoid internally. Combining the two steps in one function is more numerically stable than applying a sigmoid in the model and then using BCELoss.
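For a logit x and a target y in {0, 1}, the loss is -[y log σ(x) + (1 - y) log(1 - σ(x))], averaged over the batch when reduction='mean'; pos_weight=1, as used in the training code, leaves this unchanged. The NumPy sketch below compares that naive formula with the algebraically equivalent form max(x, 0) - x*y + log(1 + exp(-|x|)), which never exponentiates a large positive number. This is the standard stable rewrite, shown as an illustration; it is not a copy of PyTorch's internal code, and both function names are hypothetical.

```python
import numpy as np

def bce_naive(logits, targets):
    """Hypothetical helper: sigmoid followed by binary cross-entropy, computed directly."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return np.mean(-(targets * np.log(probs) + (1 - targets) * np.log(1 - probs)))

def bce_with_logits_stable(logits, targets):
    """Hypothetical helper: same loss as max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return np.mean(np.maximum(logits, 0) - logits * targets
                   + np.log1p(np.exp(-np.abs(logits))))

logits = np.array([2.0, -1.0, 0.5, -3.0])
targets = np.array([1.0, 0.0, 0.0, 1.0])
print(np.isclose(bce_naive(logits, targets), bce_with_logits_stable(logits, targets)))  # True

# For a large logit, sigmoid(x) rounds to exactly 1.0, so the naive log(1 - p) would be -inf;
# the stable form stays finite
big = np.array([40.0])
print(bce_with_logits_stable(big, np.array([0.0])))   # 40.0
```

The two functions agree on ordinary inputs; they only differ when a confident (large) logit pushes the probability to the edge of floating-point precision.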

Implementation of the training routine:

# initialize the model
model = CovidXRayDetection()

# optimizer
learning_rate = 0.01
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Loss function: sigmoid + binary cross-entropy on raw logits
# (pos_weight=1 weights both classes equally, same as the default)
loss_function = nn.BCEWithLogitsLoss(reduction='mean', pos_weight=torch.tensor([1.]))

# epochs
epochs = 50
losses = []

for epoch in range(epochs):
    model.train()  # enable dropout
    mini_batch_losses = []

    for x_batch, y_batch in train_dataloader:
        optimizer.zero_grad()       # clear gradients from the previous step
        y_pred = model(x_batch)     # (N, 1) logits
        # labels are int64 with shape (N,); BCE needs float targets with shape (N, 1)
        loss = loss_function(y_pred, y_batch.unsqueeze(1).float())
        loss.backward()
        optimizer.step()
        mini_batch_losses.append(loss.item())

    # average mini-batch loss for this epoch
    epoch_loss = np.mean(mini_batch_losses)
    losses.append(epoch_loss)

    print(f"Epoch: {epoch}, Loss: {epoch_loss}")
Epoch: 0, Loss: 0.7217250019311905
Epoch: 1, Loss: 0.7006456106901169
Epoch: 2, Loss: 0.6594521775841713
Epoch: 3, Loss: 0.622551117092371
Epoch: 4, Loss: 0.4777718745172024
Epoch: 5, Loss: 0.42663420736789703
Epoch: 6, Loss: 0.3091115616261959
Epoch: 7, Loss: 0.2671959847211838
Epoch: 8, Loss: 0.26181135699152946
Epoch: 9, Loss: 0.2637191619724035
Epoch: 10, Loss: 0.2649126425385475
Epoch: 11, Loss: 0.2811101209372282
Epoch: 12, Loss: 0.23474631272256374
Epoch: 13, Loss: 0.23240581166464835
Epoch: 14, Loss: 0.24001342244446278
Epoch: 15, Loss: 0.22849101945757866
Epoch: 16, Loss: 0.21720479056239128
Epoch: 17, Loss: 0.19048296194523573
Epoch: 18, Loss: 0.1754478020593524
Epoch: 19, Loss: 0.18594214040786028
Epoch: 20, Loss: 0.20017251884564757
Epoch: 21, Loss: 0.23603381495922804
Epoch: 22, Loss: 0.1866110609844327
Epoch: 23, Loss: 0.20468028541654348
Epoch: 24, Loss: 0.2053750939667225
Epoch: 25, Loss: 0.2322987513616681
Epoch: 26, Loss: 0.1793487872928381
Epoch: 27, Loss: 0.15538806840777397
Epoch: 28, Loss: 0.20436909375712276
Epoch: 29, Loss: 0.19200498051941395
Epoch: 30, Loss: 0.15856635384261608
Epoch: 31, Loss: 0.15390582103282213
Epoch: 32, Loss: 0.19708898663520813
Epoch: 33, Loss: 0.21701610274612904
Epoch: 34, Loss: 0.3220353554934263
Epoch: 35, Loss: 0.24722643941640854
Epoch: 36, Loss: 0.22207679599523544
Epoch: 37, Loss: 0.17871477361768484
Epoch: 38, Loss: 0.15754215978085995
Epoch: 39, Loss: 0.1523445942439139
Epoch: 40, Loss: 0.20394558552652597
Epoch: 41, Loss: 0.2540366337634623
Epoch: 42, Loss: 0.20610464923083782
Epoch: 43, Loss: 0.2029423825442791
Epoch: 44, Loss: 0.2104199305176735
Epoch: 45, Loss: 0.16047851648181677
Epoch: 46, Loss: 0.16203059442341328
Epoch: 47, Loss: 0.1629633717238903
Epoch: 48, Loss: 0.1281860782764852
Epoch: 49, Loss: 0.10838314751163125
fig = plt.figure(figsize=(10, 3))

plt.plot(range(1, epochs + 1), losses)
plt.title('Model Training Loss by Epoch')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
[Figure: Model Training Loss by Epoch]

4. Model Prediction

To keep things simple, I do not track a validation loss during training. Instead, I use the validation set as a held-out test set to measure the model's accuracy. Because ImageFolder assigns labels alphabetically (Covid = 0, Normal = 1), the model's single logit scores the Normal class: an image is predicted as Normal when its sigmoid output exceeds 0.5, and as Covid otherwise.
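Since σ(x) > 0.5 exactly when x > 0, thresholding the sigmoid at 0.5 is the same as checking the sign of the raw logit, and the resulting 0/1 prediction must then be read through ImageFolder's label order. The NumPy sketch below shows both. The logits are made up, and the class list is written out by hand under the assumption that the folders are named Covid and Normal as in the tree above; with a real dataset, validation_dataset.classes gives the same list.

```python
import numpy as np

# Assumed label order: ImageFolder sorts folder names, so Covid -> 0, Normal -> 1
classes = ["Covid", "Normal"]

logits = np.array([-2.3, -0.1, 0.0, 0.4, 3.1])   # made-up model outputs
probs = 1.0 / (1.0 + np.exp(-logits))

by_prob = (probs > 0.5).astype(int)
by_sign = (logits > 0).astype(int)
print(np.array_equal(by_prob, by_sign))       # True
print([classes[i] for i in by_prob])
# ['Covid', 'Covid', 'Covid', 'Normal', 'Normal']
```

A logit of exactly 0 gives a probability of exactly 0.5, which the strict > comparison assigns to label 0 (Covid).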

validation_dataset = datasets.ImageFolder('data/Val/', transform=transform)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False)
test_predictions, test_outcomes = [], []

model.eval()  # disable dropout

with torch.no_grad():  # no gradients needed for inference
    for x_input, y_outcomes in validation_dataloader:
        # predicted label: 1 if sigmoid(logit) > 0.5, else 0; squeeze(1) keeps a 1-D tensor
        preds = (torch.sigmoid(model(x_input)) > 0.5).squeeze(1).long()
        test_predictions.extend(preds.tolist())

        # true labels
        test_outcomes.extend(y_outcomes.tolist())
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix

print('Accuracy Score:', accuracy_score(test_outcomes, test_predictions, normalize=True))
Accuracy Score: 0.9666666666666667

sns.heatmap(confusion_matrix(test_outcomes, test_predictions), annot=True,
            xticklabels=validation_dataset.classes, yticklabels=validation_dataset.classes)
plt.title('Test Data Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()
[Figure: Test Data Confusion Matrix]
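Accuracy alone can hide which class the errors fall on, which matters for a screening task. sklearn's confusion_matrix puts true labels on the rows and predicted labels on the columns, in sorted label order, so with Covid = 0 the top-left cell counts COVID images correctly flagged. The NumPy sketch below builds the same 2 x 2 layout and derives each class's recall (sensitivity for Covid, and specificity when Normal is treated as the negative class). confusion_2x2 is a hypothetical helper, and the labels are made up for illustration; they are not the model's actual predictions.

```python
import numpy as np

def confusion_2x2(y_true, y_pred):
    """Hypothetical helper: rows = true label, columns = predicted label (sklearn layout)."""
    cm = np.zeros((2, 2), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

# Made-up labels; 0 = Covid, 1 = Normal (ImageFolder's alphabetical order)
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 1, 0]

cm = confusion_2x2(y_true, y_pred)
print(cm)
# [[3 1]
#  [1 3]]

sensitivity = cm[0, 0] / cm[0].sum()   # share of true Covid images predicted Covid
specificity = cm[1, 1] / cm[1].sum()   # share of true Normal images predicted Normal
print(sensitivity, specificity)        # 0.75 0.75
```

Applied to test_outcomes and test_predictions, the same two ratios show whether the model's few mistakes are missed COVID cases or false alarms.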