COSC 2673/2793 | Machine Learning | Assignment 2
¶

Student Names: Thomas Williams, Hannah Mac
¶

Student numbers: s4005637, s4005524
¶


Reading Data¶

The following code uploads the image classification data to the notebook work environment and unzips it

In [52]:
# Importing the image data.
import zipfile
with zipfile.ZipFile(r"Image_classification_data.zip", 'r') as zip_ref:
    zip_ref.extractall('./')

The dataset consists of the images ("patch_images") and a csv file ("data_labels_mainData"). The labels and the image paths are in the CSV file.

The following code randomly splits the data into train/val/test

The csv file contains the following data:

  • InstanceID (int)
  • PatientID (int)
  • ImageName (obj)
  • cellTypeName (obj)
  • cellType (int)
  • isCancerous (int)
In [54]:
# Visualising the csv file storing the links to the image data.
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd

data = pd.read_csv('./data_labels_mainData.csv')
dataExtra = pd.read_csv('./data_labels_extraData.csv')

data.head()
Out[54]:
InstanceID patientID ImageName cellTypeName cellType isCancerous
0 22405 1 22405.png fibroblast 0 0
1 22406 1 22406.png fibroblast 0 0
2 22407 1 22407.png fibroblast 0 0
3 22408 1 22408.png fibroblast 0 0
4 22409 1 22409.png fibroblast 0 0

EDA¶

In the code blocks below is the exploratory data analysis. The findings are discussed in a mark down at the end of the section.¶

In [58]:
# Checking the distribution of values in all columns.
data.describe()
Out[58]:
InstanceID patientID cellType isCancerous
count 9896.000000 9896.000000 9896.000000 9896.000000
mean 10193.880154 29.762025 1.501516 0.412187
std 6652.912660 17.486553 0.954867 0.492253
min 1.000000 1.000000 0.000000 0.000000
25% 4135.750000 14.000000 1.000000 0.000000
50% 9279.500000 26.000000 2.000000 0.000000
75% 16821.250000 47.000000 2.000000 1.000000
max 22444.000000 60.000000 3.000000 1.000000
In [60]:
# Checking for null/missing values and observing data types.
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9896 entries, 0 to 9895
Data columns (total 6 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   InstanceID    9896 non-null   int64 
 1   patientID     9896 non-null   int64 
 2   ImageName     9896 non-null   object
 3   cellTypeName  9896 non-null   object
 4   cellType      9896 non-null   int64 
 5   isCancerous   9896 non-null   int64 
dtypes: int64(4), object(2)
memory usage: 464.0+ KB
In [62]:
# Checking the distribution of cells for patient 4.
data1 = data[data["patientID"] == 4]
data1.head(20)
Out[62]:
InstanceID patientID ImageName cellTypeName cellType isCancerous
188 18589 4 18589.png fibroblast 0 0
189 18590 4 18590.png fibroblast 0 0
190 18591 4 18591.png fibroblast 0 0
191 18592 4 18592.png fibroblast 0 0
192 18593 4 18593.png fibroblast 0 0
193 18594 4 18594.png fibroblast 0 0
194 18595 4 18595.png fibroblast 0 0
195 18596 4 18596.png fibroblast 0 0
196 18597 4 18597.png fibroblast 0 0
197 18598 4 18598.png fibroblast 0 0
198 18599 4 18599.png fibroblast 0 0
199 18600 4 18600.png fibroblast 0 0
200 18601 4 18601.png fibroblast 0 0
201 18602 4 18602.png fibroblast 0 0
202 18603 4 18603.png fibroblast 0 0
203 18604 4 18604.png fibroblast 0 0
204 18605 4 18605.png fibroblast 0 0
205 18606 4 18606.png fibroblast 0 0
206 18607 4 18607.png fibroblast 0 0
207 18608 4 18608.png fibroblast 0 0
In [64]:
# Checking how that compares to the distribution of cells for patient 1.
data1 = data[data["patientID"] == 1]
data1.head(20)
Out[64]:
InstanceID patientID ImageName cellTypeName cellType isCancerous
0 22405 1 22405.png fibroblast 0 0
1 22406 1 22406.png fibroblast 0 0
2 22407 1 22407.png fibroblast 0 0
3 22408 1 22408.png fibroblast 0 0
4 22409 1 22409.png fibroblast 0 0
5 22410 1 22410.png fibroblast 0 0
6 22411 1 22411.png fibroblast 0 0
7 22412 1 22412.png fibroblast 0 0
8 22413 1 22413.png fibroblast 0 0
9 22414 1 22414.png fibroblast 0 0
10 22415 1 22415.png fibroblast 0 0
11 22417 1 22417.png inflammatory 1 0
12 22418 1 22418.png inflammatory 1 0
13 22419 1 22419.png inflammatory 1 0
14 22420 1 22420.png inflammatory 1 0
15 22421 1 22421.png inflammatory 1 0
16 22422 1 22422.png inflammatory 1 0
17 22423 1 22423.png others 3 0
18 22424 1 22424.png others 3 0
In [66]:
# Visualising the distribution of cell types across rows and identifying class imbalance.
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

plt.figure(figsize=(6,4)) # Plotting the distribution of isCancerous
sns.countplot(data=data, x='isCancerous')
plt.title("Distribution of isCancerous Labels")
plt.xlabel("isCancerous (0 = Non-cancerous, 1 = Cancerous)")
plt.ylabel("Count")
plt.show()

print(data['isCancerous'].value_counts()) # Printing the raw values.

plt.figure(figsize=(8,4)) # Plotting the distribution of cellType.
sns.countplot(data=data, x='cellType', order=data['cellType'].value_counts().index)
plt.title("Distribution of Cell Types")
plt.xlabel("Cell Type")
plt.ylabel("Count")
plt.xticks(rotation=45)
plt.show()

print(data['cellType'].value_counts()) # Plotting the raw values.
No description has been provided for this image
isCancerous
0    5817
1    4079
Name: count, dtype: int64
No description has been provided for this image
cellType
2    4079
1    2543
0    1888
3    1386
Name: count, dtype: int64
In [68]:
# Loading Images into the virtual enviroment for use.
import os
from PIL import Image
import matplotlib.pyplot as plt

image_dir = 'patch_images/'
image_files = os.listdir(image_dir)

print("Total images found:", len(image_files))
print("Sample filenames:", image_files[:5])
Total images found: 20280
Sample filenames: ['1.png', '10.png', '100.png', '1000.png', '10000.png']
In [70]:
# Visual inspection of sample images.
for i, file in enumerate(image_files[:5]):
    img_path = os.path.join(image_dir, file)
    img = Image.open(img_path)
    print(f"Image {i+1}: {file}, size: {img.size}, mode: {img.mode}")
    
    plt.imshow(img)
    plt.title(f"{file}")
    plt.axis('off')
    plt.show()
Image 1: 1.png, size: (27, 27), mode: RGB
No description has been provided for this image
Image 2: 10.png, size: (27, 27), mode: RGB
No description has been provided for this image
Image 3: 100.png, size: (27, 27), mode: RGB
No description has been provided for this image
Image 4: 1000.png, size: (27, 27), mode: RGB
No description has been provided for this image
Image 5: 10000.png, size: (27, 27), mode: RGB
No description has been provided for this image
In [72]:
# Checking for any missing images in dataset
import pandas as pd

df = pd.read_csv("data_labels_mainData.csv")
dfExtra = pd.read_csv("data_labels_extraData.csv")

csv_image_names = set(df['ImageName'])
csvExtra_image_names = set(dfExtra['ImageName'])
actual_image_names = set(image_files)

missing_in_folder = csv_image_names - actual_image_names
extra_in_folder = actual_image_names - csv_image_names

print("Missing images in folder:", len(missing_in_folder))
print("Extra images in folder:", len(extra_in_folder))
print("Total images:", len(actual_image_names))
print("Main data images:", len(csv_image_names))
print("Extra data images:", len(csvExtra_image_names))
Missing images in folder: 0
Extra images in folder: 10384
Total images: 20280
Main data images: 9896
Extra data images: 10384

=== Visual Inspection of Images ===¶

In [75]:
# Plot a few cancerous and non-cancerous images side by side.
fig, axs = plt.subplots(2, 5, figsize=(15, 6))
classes = [0, 1]

for row, label in enumerate(classes):
    subset = df[df['isCancerous'] == label].sample(5, random_state=42)
    
    for col, filename in enumerate(subset['ImageName']):
        img_path = os.path.join(image_dir, filename)
        img = Image.open(img_path)
        axs[row, col].imshow(img)
        axs[row, col].axis('off')
        axs[row, col].set_title(f"Label: {label}")

plt.suptitle("Visual Comparison: Cancerous vs Non-Cancerous Cells", fontsize=16)
plt.tight_layout()
plt.show()
No description has been provided for this image
In [77]:
# Visualising images by cell type to observe an obvious patterns.
cell_types = df['cellType'].unique()

fig, axs = plt.subplots(len(cell_types), 5, figsize=(15, 3 * len(cell_types)))

for i, ctype in enumerate(cell_types):
    samples = df[df['cellType'] == ctype].sample(5, random_state=42)
    
    for j, filename in enumerate(samples['ImageName']):
        img = Image.open(os.path.join(image_dir, filename))
        axs[i, j].imshow(img)
        axs[i, j].axis('off')
        axs[i, j].set_title(ctype)

plt.suptitle("Sample Images by Cell Type", fontsize=16)
plt.tight_layout()
plt.show()

# 0 = fibroblast
# 1 = inflammatory
# 2 = epithelial
# 3 = others
No description has been provided for this image

=== Image Statistics ===¶

In [ ]:
import numpy as np
from PIL import Image
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# Store per-image RGB means and class labels
image_means = []
class_labels = []

for _, row in tqdm(df.iterrows(), total=len(df)):
    filename = row['ImageName']
    label = row['cellType']  # or 'isCancerous' if binary classification

    img_path = os.path.join(image_dir, filename)
    img = Image.open(img_path).convert('RGB')
    arr = np.array(img)

    # Compute mean RGB for this image
    mean_rgb = np.mean(arr.reshape(-1, 3), axis=0)

    image_means.append(mean_rgb)
    class_labels.append(label)

# Convert to arrays for processing
image_means = np.array(image_means)
class_labels = np.array(class_labels)

# Sort by class label
sorted_indices = np.argsort(class_labels)
image_means_sorted = image_means[sorted_indices]
class_labels_sorted = class_labels[sorted_indices]

# Plot
plt.figure(figsize=(12, 6))
plt.plot(image_means_sorted[:, 0], label='Red Mean', color='red', alpha=0.6)
plt.plot(image_means_sorted[:, 1], label='Green Mean', color='green', alpha=0.6)
plt.plot(image_means_sorted[:, 2], label='Blue Mean', color='blue', alpha=0.6)

# add class label bands for visualization
unique_labels, label_positions = np.unique(class_labels_sorted, return_index=True)
for idx, label in zip(label_positions, unique_labels):
    plt.axvline(x=idx, color='gray', linestyle='--', alpha=0.4)
    plt.text(idx, 260, f'Class {label}', rotation=45, verticalalignment='bottom')

plt.title('Per-Image RGB Mean Values (Grouped by Class)')
plt.xlabel('Image Index (Sorted by Class)')
plt.ylabel('Mean Pixel Value (0–255)')
plt.legend()
plt.tight_layout()
plt.show()
 56%|█████▌    | 5532/9896 [01:03<00:47, 92.23it/s] 
In [19]:
def get_class_pixel_stats(label):
    subset = df[df['isCancerous'] == label]
    pixels = []
    for filename in subset['ImageName']:
        img = Image.open(os.path.join(image_dir, filename)).convert('RGB')
        pixels.append(np.array(img))
    flat = np.stack(pixels).reshape(-1, 3)
    return np.mean(flat, axis=0), np.std(flat, axis=0)

mean_0, std_0 = get_class_pixel_stats(0)
mean_1, std_1 = get_class_pixel_stats(1)

print("Non-cancerous Mean:", mean_0, "| Std:", std_0)
print("Cancerous Mean:", mean_1, "| Std:", std_1)
Non-cancerous Mean: [201.96138913 157.94195246 206.81942219] | Std: [43.72277866 50.90752328 32.67643504]
Cancerous Mean: [181.59094206 138.98914915 203.17913829] | Std: [40.76874649 45.89625997 29.59015706]

EDA Observations¶

  • The dataset has 9,896 instances and 6 columns (InstanceID, patientID, ImageName, cellTypeName, cellType, and isCancerous)
  • No missing (null) values are present in any column
  • Class Distributions:
    • isCancerous; there is a moderate class imablance
      • 0 (non-cancerous): 5,817
      • 1 (cancerous): 4,079
      • Non-cancerous cells are about 42% more frequent
    • cellType; there is a significant class imbalance
      • Class 2 is the most frequent, followed by 1, 0, and 3:
      • Class 2: ~4,100 samples
      • Class 3: ~1,000 samples
  • There is clear variation between different cellType values in terms of color and texture, this could help in feature learning for cnn models
  • For pateintID == 4, all samples are similar, this suggest patients may have uniform cell characteristics which could introduce patient-specific biases into models if not split properly

Class Imbalance¶

By observing the distribution of the isCancerous and cellType columns, it is clear that some classes are more present in the dataset than others. For "isCancerous" this dataset contains many more class "0" samples than class "1". This immediately effected the performance of the intial models as it predicted only for the majority class even after regularisation.

This overfitting can be handled via either data augmentation or removing the some of the samples from the majority class. As we do not want to reduce our dataset size, data augmentation of the minority class will be applied in the pre-processing phase.

Data Splitting¶

In the code blocks below is the splitting of the dataset. The findings are discussed in a mark down at the end of the section.¶

In [8]:
# To avoid the spillage of patients across the data splits, we ensure that each patient's data is only included in a singular dataset.
unique_patients = data['patientID'].unique()

train_ids, temp_ids = train_test_split(unique_patients, test_size=0.4, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

train_data = df[df['patientID'].isin(train_ids)]
val_data = df[df['patientID'].isin(val_ids)]
test_data = df[df['patientID'].isin(test_ids)]
In [9]:
# Here you can see that there are no overlapping paitents across data splits
print(train_data['patientID'].unique())
print(test_data['patientID'].unique())
print(val_data['patientID'].unique())
[ 2  3  8 10 11 12 15 16 17 19 21 22 23 24 25 26 27 28 29 30 31 33 36 38
 39 40 42 43 45 48 50 52 54 56 57 60]
[ 1  4  6  9 13 20 32 37 41 51 55 58]
[ 5  7 14 18 34 35 44 46 47 49 53 59]
In [10]:
# After splitting the data the distribution is visualised across the splits.
train_data_copy = train_data.copy()
train_data_copy['Split'] = 'Train'
val_data_copy = val_data.copy()
val_data_copy['Split'] = 'Validation'
test_data_copy = test_data.copy()
test_data_copy['Split'] = 'Test'
combined_data = pd.concat([train_data_copy, val_data_copy, test_data_copy])

plt.figure(figsize=(8, 5))
sns.countplot(data=combined_data, x='Split', hue='isCancerous', palette='Set2')

plt.title("isCancerous Class Distribution per Data Split")
plt.xlabel("Data Split")
plt.ylabel("Count")
plt.legend(title="isCancerous", labels=["Non-cancerous", "Cancerous"])
plt.tight_layout()
plt.show()

plt.figure(figsize=(8, 5))
sns.countplot(data=combined_data, x='Split', hue='cellType', palette='Set2')

plt.title("CellType Class Distribution per Data Split")
plt.xlabel("Data Split")
plt.ylabel("Count")
plt.legend(title="cellType", labels=["fibroblast", "inflammatory", "epithelial", "others"])
plt.tight_layout()
plt.show()

# 0 = fibroblast
# 1 = inflammatory
# 2 = epithelial
# 3 = others
No description has been provided for this image
No description has been provided for this image

Data Splitting Observations¶

For our dataset we chose to split our data into train, validation and test data splits instead of using a cross validation approach. We chose this approach as our dataset is of a significant size and cross-validation would be very computationally expensive. We would also not be able to control if a patient's data would leak across the validation and train set (using cross-validation), so we opted for using clear splits where we could be sure each patients data stayed within one split.

We further prevented data leakage by designating every patients data to only a single dataset. We did this because in the EDA patients demonstrated having more data in some cellTypes than others. If a patients data was in both the training and testing set it would skew the models performance results and we would not be able to accurately assess the performance of the model.

Data Preprocessing (Binary Classification)¶

In [47]:
from collections import Counter
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
from torch.utils.data import ConcatDataset, DataLoader

# This class handles the image data and creates the image path. It also applies transformations if needed.
class ColonCancerDataset(Dataset):
    def __init__(self, data, img_dir, transform=None):
        self.data = data
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row['ImageName'])
        label = torch.tensor(row['isCancerous'], dtype=torch.float)
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
In [48]:
# This class expands upon the previous class by augmenting datasets that are passed in.
# This augmentation is done to handle class imbalance.
class AugmentedDataset(Dataset):
    def __init__(self, df, image_dir, transform, target_class, n_samples):
        self.df = df[df['isCancerous'] == target_class].reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        real_idx = idx % len(self.df)
        row = self.df.iloc[real_idx]
        image_path = os.path.join(self.image_dir, row['ImageName'])
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(row['isCancerous'], dtype=torch.long)
        return image, label
In [49]:
# These are the transformations applied to the datasets. They are called from within the classes.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
In [50]:
# Here we deal with the class imbalance by finding the difference between the majority and minority class.
class_counts = train_data['isCancerous'].value_counts()
majority_class = class_counts.idxmax()
minority_class = class_counts.idxmin()
n_to_augment = class_counts[majority_class] - class_counts[minority_class]

majority_df = train_data[train_data['isCancerous'] == majority_class] # Then we get a sample of the minority class.
minority_df = train_data[train_data['isCancerous'] == minority_class]
majority_df = majority_df.sample(n=len(minority_df), random_state=42)


# We create an instance of "AugmentedDataset" to increase the amount of the minority class.
augmented_minority_dataset = AugmentedDataset(train_data, image_dir="patch_images", transform=augment_transform, target_class=minority_class, n_samples=n_to_augment)
# Then we combine the original train data set and the augmented dataset.
train_base_dataset = ColonCancerDataset(pd.concat([majority_df, minority_df]), img_dir="patch_images", transform=train_transform)
balanced_train_dataset = ConcatDataset([train_base_dataset, augmented_minority_dataset])
# We create new instances of "ColonCancerDataset" for the validation and test dataset.
val_dataset = ColonCancerDataset(val_data, img_dir="patch_images", transform=val_test_transform)
test_dataset = ColonCancerDataset(test_data, img_dir="patch_images", transform=val_test_transform)

train_loader = DataLoader(balanced_train_dataset, batch_size=32, shuffle=True) # This is just intialising the loaders for the image data.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print("Train size:", len(balanced_train_dataset)) # Here we print out the sizes of each dataset after augmentation.
print("Val size:", len(val_dataset))
print("Test size:", len(test_dataset))
Train size: 5787
Val size: 1782
Test size: 2327

Data Pre-Processing Justification¶

When creating our dataset classes we chose to use normalisation to help with the convergence speed during training and augmentation to deal with the class imbalance.

We normalised the images using transforms.Normalize, which scales the RGB pixels to be a range around 0. This is especially important for histopathelogy images as they can be taken in different clinics with different protocols, which could effect the lighting and quality of the images. By normalising the images we increase the generalisation ability of the model.

We performed data augmentation after identifying the class imbalance. This includes flipping the images horizontally and vertically, rotating the image and introducing colour jitter. The first two transformations are done to introduce spatial variety as colon cancer cells are not orientation specific. Colour jitter, which simulates different lighting conditions, will allow the model to learn the underlying patterns in the cells instead of relying on the RGB distribution of the image.

Model Training¶

== Binary Classification ==¶

1. Basic CNN Model¶

In [51]:
import torch.nn as nn
import torch.nn.functional as func

class baseCNN(nn.Module): # This first Model is a basic CNN Model.
    def __init__(self, dropout_probability=0.5):
        super(baseCNN, self).__init__()
        self.con1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2,2)
        self.con2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(32 * 6 * 6, 64)
        self.dropout = nn.Dropout(p=dropout_probability)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = self.pool(func.relu(self.con1(x)))
        x = self.pool(func.relu(self.con2(x)))
        x = x.view(x.size(0), -1)
        x = func.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x
In [52]:
from torch.nn.modules.loss import BCEWithLogitsLoss
from torch.optim import lr_scheduler as lr

model = baseCNN() # This is initialising the hyperparameters of the model.

neg_count = (train_data['isCancerous'] == 0).sum()
pos_count = (train_data['isCancerous'] == 1).sum()
pos_weight = torch.tensor(neg_count / float(pos_count), dtype=torch.float)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = lr.StepLR(optimizer, step_size=4, gamma=0.1)
In [53]:
# This is the training loop that we use for all our binary models.
def train_model(model, train_data, val_data, criterion, optim, scheduler, epochs=10):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_data:

            optim.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.unsqueeze(1))
            loss.backward()
            optim.step()

            running_loss += loss.item()
            preds = (outputs > 0.5).float()
            correct += (preds.view(-1) == labels).sum().item()
            total += labels.size(0)

        history['loss'].append(running_loss / len(train_loader))
        history['accuracy'].append(correct / total)

        print(f"Epoch {epoch+1}/{epochs} - Training Loss: {running_loss/len(train_loader):.4f}")
        
        scheduler.step()

        validate_model(model, val_data)
In [54]:
# The training loop calls the validation loop.
def validate_model(model, val_data):
    model.eval()
    correct = 0
    total = 0
    val_loss = 0

    with torch.no_grad():
        for inputs, labels in val_data:
            outputs = model(inputs)
            loss = criterion(outputs, labels.unsqueeze(1))
            val_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).int()
            correct += (preds.view(-1) == labels).sum().item()
            total += labels.size(0)

    history['val_loss'].append(val_loss / len(val_loader))
    history['val_accuracy'].append(correct / total)
    
    print(f"Validation Accuracy: {100 * correct / total:.2f}%")
In [55]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt

# After training we use this testing method to see the results of our model on unseen data.
def test_model(model, test_data):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_data:
            outputs = model(inputs)
            preds = (outputs > 0.5).int()
            correct += (preds.view(-1) == labels).sum().item()
            total += labels.size(0)

            preds = (torch.sigmoid(outputs) > 0.5).int().view(-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

    cm = confusion_matrix(all_labels, all_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Non-cancerous", "Cancerous"])
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()

    print("Classification Report:")
    print(classification_report(all_labels, all_preds, target_names=["0", "1"]))
In [56]:
def plot_learning_curve(train_loss, val_loss, train_metric, val_metric, metric_name='Accuracy'): # This is just a model for plotting the learning curve of each model.
    epochs = range(1, len(train_loss) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.plot(epochs, train_loss, 'b-', label='Training Loss')
    ax1.plot(epochs, val_loss, 'r--', label='Validation Loss')
    ax1.set_title('Loss Curve')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()

    ax2.plot(epochs, train_metric, 'b-', label=f'Training {metric_name}')
    ax2.plot(epochs, val_metric, 'r--', label=f'Validation {metric_name}')
    ax2.set_title(f'{metric_name} Curve')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel(metric_name)
    ax2.legend()

    plt.tight_layout()
    plt.show()
In [57]:
# Here the model is train, validated and evaluated against the test data.
history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=10)
test_model(model, test_loader)
plot_learning_curve(history['loss'], history['val_loss'], history['accuracy'], history['val_accuracy'], metric_name='Accuracy')
Epoch 1/10 - Training Loss: 0.5080
Validation Accuracy: 86.64%
Epoch 2/10 - Training Loss: 0.3391
Validation Accuracy: 87.21%
Epoch 3/10 - Training Loss: 0.3124
Validation Accuracy: 81.76%
Epoch 4/10 - Training Loss: 0.2885
Validation Accuracy: 84.29%
Epoch 5/10 - Training Loss: 0.2572
Validation Accuracy: 86.36%
Epoch 6/10 - Training Loss: 0.2492
Validation Accuracy: 87.09%
Epoch 7/10 - Training Loss: 0.2547
Validation Accuracy: 85.07%
Epoch 8/10 - Training Loss: 0.2504
Validation Accuracy: 86.53%
Epoch 9/10 - Training Loss: 0.2432
Validation Accuracy: 86.42%
Epoch 10/10 - Training Loss: 0.2478
Validation Accuracy: 85.97%
Test Accuracy: 89.69%
No description has been provided for this image
Classification Report:
              precision    recall  f1-score   support

           0       0.92      0.87      0.90      1221
           1       0.87      0.92      0.89      1106

    accuracy                           0.90      2327
   macro avg       0.90      0.90      0.90      2327
weighted avg       0.90      0.90      0.90      2327

No description has been provided for this image

Basic Model Observations¶

This model is a basic convolutional neural network, we started with this model to have a bench mark for performance.

The architecture has two convolutional layers. The first layer has 16 filters and the second layer increases this to 32, both with a kernel size of 3 and padding of 1 to keep the original dimensions. Each convolutional layer was followed by a ReLu activaction and max-pooling layer with a kernel size of 2. The output of the convolutional layers was flattened and passed through a fully connected layer with a ReLu activaction, followed by a dropout layer(of 50%) to reduc overfitting. The final layer of the model produced a raw logit as a prediction.

We chose this simplistic CNN because we wanted to first explore CNN's ability to learn spatial patterns in image data. For our histopathelogy dataset, the patterns in texture and spatial structure were important in differentiating between cancerous and non-cancerous cells. So understanding the performance of a simplistic model allowed us to gauge the performance of other models.

Although we already handled the class imbalance in pre-processing, we added a pos_weight parameter and passed it into the BCEWithLogitsLoss function. By adjusting the loss function to penalise false negatives more heavily, we made the model more sensitive to class "1". Therefor further compensating for the class imbalance.

To optimise the model we used Adam optimiser with a learning rate of 0.001 and a weight decay of 1e-5 to introduce L2 regularisation. We also implemented a learning rate scheduler which reduced the learning rate by a factor of 0.1 every 4 epochs which helps the model to fine-tune weights in later epochs.

Loss & Accuracy¶

The loss between each epoch for the training loop demonstrates a good fit of the model as it initally decreases and then stabilises. The validation loss however is irregular. It fluctuates between 0.40 and 0.35. Similarly, the validation accuracy fluctuates between 0.85 and 0.89. This strange behaviour would typically be indicitive of overfitting or too high of a learning rate. However, from the test set confusion matrix and "typical" performance of the training set through each epoch we can outrule overfitting.

Classification Matrix¶

From the classification matrix we can see that the model is not overfitting on one class. The classification report also tells us the model is sensitive to the cancerous class, with a recall of 0.90 and precision of 0.88. Therefor the class imbalance has been properly handled with the augmentation and penalisation of the majority class.

2. Residual CNN¶

In [58]:
class ResidualBlock(nn.Module): # This is a residual block that is used in our residual model.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.same_channels = in_channels == out_channels

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if in_channels != out_channels:
            self.shortcut = nn.Sequential(nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels)))
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out += identity

        return self.relu(out)
In [59]:
from torch import Tensor # This is the main model class for our residual model.
class ResidualModel(nn.Module):
    def __init__(self, in_channels: int =1):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

        self.layer1 = nn.Sequential(ResidualBlock(16,16), ResidualBlock(16,16), nn.AvgPool2d(kernel_size=2))
        self.layer2 = nn.Sequential(ResidualBlock(16,32), ResidualBlock(32,32), nn.AvgPool2d(kernel_size=2))

        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.5)
        self.linear = nn.Linear(32 * 6 * 6, 1)

    def forward(self, x: Tensor):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear(x)

        return x
In [60]:
model = ResidualModel(in_channels=3) # Here the model is trained, valiadted and tested. The hyperparameters are also initalised.

neg_count = (train_data['isCancerous'] == 0).sum()
pos_count = (train_data['isCancerous'] == 1).sum()
pos_weight = torch.tensor(neg_count / float(pos_count), dtype=torch.float)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = lr.StepLR(optimizer, step_size=15, gamma=0.1)

history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=10)
test_model(model, test_loader)
plot_learning_curve(history['loss'], history['val_loss'], history['accuracy'], history['val_accuracy'], metric_name='Accuracy')
Epoch 1/10 - Training Loss: 0.3812
Validation Accuracy: 82.72%
Epoch 2/10 - Training Loss: 0.2981
Validation Accuracy: 85.69%
Epoch 3/10 - Training Loss: 0.2620
Validation Accuracy: 86.08%
Epoch 4/10 - Training Loss: 0.2479
Validation Accuracy: 87.26%
Epoch 5/10 - Training Loss: 0.2590
Validation Accuracy: 87.82%
Epoch 6/10 - Training Loss: 0.2496
Validation Accuracy: 88.27%
Epoch 7/10 - Training Loss: 0.2494
Validation Accuracy: 86.64%
Epoch 8/10 - Training Loss: 0.2379
Validation Accuracy: 86.70%
Epoch 9/10 - Training Loss: 0.2225
Validation Accuracy: 88.44%
Epoch 10/10 - Training Loss: 0.2311
Validation Accuracy: 87.77%
Test Accuracy: 89.17%
No description has been provided for this image
Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.89      0.89      1221
           1       0.88      0.88      0.88      1106

    accuracy                           0.89      2327
   macro avg       0.89      0.89      0.89      2327
weighted avg       0.89      0.89      0.89      2327

No description has been provided for this image

Residual Model Observations¶

We chose a residual convolutional neural network to try find deeper hierarchal features in the image data. The architecture uses residual blocks, which allow the model to learn identity mapping and reduce the vanishing gradient problem. In the case of histopathological classification, where the cancerous variations of cells can be subtle, this function is particularly useful.

The model begins with a convolutional layer of 16 filters, then batch normalisation and ReLu activation. The output gets passed through two residual stages. The first includes two residual blocks with 16 channels and a pooling layer to reduce the dimensions of the output. The second stage has four residual blocks with 32 total channels and then another pooling layer to reduce dimensionality again. The output of this is flattened, passed through a dropout layer of 0.5 before being classified with a linear layer.

The same class imbalance methods were used across all three models, hence weighted binary cross-entropy loss was present in the training loop. The residual model was also optimised with Adam optimiser like the other models.

Loss & Accuracy¶

Similarly to the baseCNN model, the training loop's decreases slightly and then appears to stabilise. The accuracy of the training loop increases slightly with each epoch, although its accuracy starts at 0.86 which is higher than BaseCNN's 0.85. Although it does not increase as drastically as BaseCNN, the biggest improvement is the increased stability of the validation loop's accuracy and loss. There is a significant decrease/inscrease for the respective loss and accuracy and then a fluctuation of values, however they remain within a close range of each other. This demonstrates that the validation sets atypical performance in baseCNN may have to do with the simplistic nature of the model.

Classification Matrix¶

The model's accuracy of 87%, with a slightly stronger accuracy for the Cancerous class compared to the nonCancerous. It had high precision of 94% for nonCancerous cases and recall of 94% for cancerous cases. This displays that most cancer cases were identified correctly. The F1-scores were almost identical, 0.87 for nonCancerous and 0.88 for cancerous, indicating that the model generalises well for both classes.

3. DenseNet¶

In [61]:
from torchvision.models.densenet import DenseNet # This is the model class for our DenseNet model.

class DenseModel(nn.Module):
    def __init__(self, num_classes=1, growth_rate=32, block_config=(6,6,6), bin_size=4, dropout=0.0):
        super(DenseModel, self).__init__()

        self.features = DenseNet(num_init_features=32, num_classes=num_classes, growth_rate=growth_rate, block_config=block_config, bn_size=bin_size, drop_rate=dropout).features

        self.features.conv = nn.Conv2d(344, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.features.pool = nn.Identity()

        with torch.no_grad():
            self.eval()
            dummy = torch.randn(1, 3, 27, 27)
            out = self.features(dummy)
            out = func.relu(out)
            out = func.adaptive_avg_pool2d(out, (1, 1)).view(1, -1)
            feature_size = out.shape[1]
            self.train()

        self.classifier = nn.Linear(feature_size, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = func.relu(features, inplace=True)
        out = func.adaptive_avg_pool2d(out, (1, 1)).view(x.size(0), -1)
        out = self.classifier(out)
        return out
In [62]:
model = DenseModel() # Here the model is trained, valiadted and tested. The hyperparameters are also initalised.

neg_count = (train_data['isCancerous'] == 0).sum()
pos_count = (train_data['isCancerous'] == 1).sum()
pos_weight = torch.tensor(neg_count / float(pos_count), dtype=torch.float)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = lr.StepLR(optimizer, step_size=15, gamma=0.1)

history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=10)
test_model(model, test_loader)
plot_learning_curve(history['loss'], history['val_loss'], history['accuracy'], history['val_accuracy'], metric_name='Accuracy')
Epoch 1/10 - Training Loss: 0.3944
Validation Accuracy: 78.79%
Epoch 2/10 - Training Loss: 0.3027
Validation Accuracy: 85.41%
Epoch 3/10 - Training Loss: 0.2854
Validation Accuracy: 85.75%
Epoch 4/10 - Training Loss: 0.3001
Validation Accuracy: 88.66%
Epoch 5/10 - Training Loss: 0.2712
Validation Accuracy: 81.31%
Epoch 6/10 - Training Loss: 0.2587
Validation Accuracy: 88.33%
Epoch 7/10 - Training Loss: 0.2544
Validation Accuracy: 89.28%
Epoch 8/10 - Training Loss: 0.2442
Validation Accuracy: 88.16%
Epoch 9/10 - Training Loss: 0.2491
Validation Accuracy: 84.68%
Epoch 10/10 - Training Loss: 0.2327
Validation Accuracy: 87.77%
Test Accuracy: 90.76%
No description has been provided for this image
Classification Report:
              precision    recall  f1-score   support

           0       0.91      0.92      0.91      1221
           1       0.91      0.90      0.90      1106

    accuracy                           0.91      2327
   macro avg       0.91      0.91      0.91      2327
weighted avg       0.91      0.91      0.91      2327

No description has been provided for this image

DenseNet Model Observations¶

We chose DenseNet architecture as it has been extensively used in biomedical image analysis, specifically for histopathological image classification. Some examples include breast cancer detection and colorectal cancer grading. Its effectiveness comes from its dense connectivity pattern as each layer receives input from the preceding layers. This encourages feature reuse and results in compact models with less parameters. These can be especially useful in histopathology where datasets can be limited.

The architecture of this model consists of three dense blocks with six layers each, a growth rate of 32 and no dropout in the dense layers to retain as much information as possible. We customised the feature extractor by replacing the final convolution with a smaller 3x3 convolution that outputted 32 feature maps to account for the 27x27 input size. We removed the pooling and an adaptive average pooling was used in the forward pass to reduce dimensionality. To classify we used a linear layer that mapped the extracted features to a logit.

The hyperparameters were set to the same values as the previous models to test the improvements based on the architecture of the model rather than its parameters.

Loss & Accuracy¶

The training loops' loss and accuracy demonstrated a good fit by decreasing initially and then stabilising. The accuracy in each epoch increased not as drastically as the BaseCNN model but more than the residual model. The validation loss showed an overall decreasing trend, however the validation accuracy fluctuated similarly to the BaseCNN model. It's range stayed within the 0.80 and 0.88 percentile, which is not siginificantly concerning.

Classification Matrix¶

The model's overall accuracy of 90% indicates that like the previous models, this model performs well on both classes. For the nonCancerous class there was a precision of 0.92 and a recall of 0.89. Therefor although most predicitions were correct, a few false nonCancerous predictions were made. The Cancerous class had a lower precision of 0.88 but a higher recall of 0.92. So while the model could identify most Cancerous cases, it did make a few false positive predictions also.

== Multiclass Classification ==¶

Data Preprocessing (Multiclass Classification)¶

In [40]:
from collections import Counter
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
from torch.utils.data import ConcatDataset, DataLoader

# This class handles the image data and creates the image path. It also applies transformations if needed.
class CellTypeDataset(Dataset):
    def __init__(self, data, img_dir, transform=None):
        self.data = data
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row['ImageName'])
        label = torch.tensor(row['cellType'], dtype=torch.long) # long allows for multiclass classification
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
In [41]:
class UnlabeledImageDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = os.path.join(self.img_dir, row['ImageName'])
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image  

    def __len__(self):
        return len(self.df)
In [42]:
# This class expands upon the previous class by augmenting datasets that are passed in.
# This augmentation is done to handle class imbalance.
class AugmentedMultiClassDataset(Dataset):
    def __init__(self, df, image_dir, transform, target_class, n_samples):
        self.df = df[df['cellType'] == target_class].reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        self.n_samples = n_samples
        self.label = target_class
        self.label_dtype = torch.long

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        real_idx = idx % len(self.df)
        row = self.df.iloc[real_idx]
        image_path = os.path.join(self.image_dir, row['ImageName'])
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(self.label, dtype=self.label_dtype)
        return image, label
In [43]:
# These are the transformations applied to the datasets. They are called from within the classes.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
In [44]:
# Dealing with class imbalances and finding the largest class
class_counts = train_data['cellType'].value_counts()
max_count = class_counts.max()

# Store datasets to combine later
base_datasets = []
augmented_datasets = []

for cell_type, count in class_counts.items():
    df_class = train_data[train_data['cellType'] == cell_type]

    # 1. If underrepresented -> augment
    if count < max_count:
        n_to_augment = max_count - count

        # Add original data
        base_datasets.append(CellTypeDataset(
            df_class,
            img_dir="patch_images",
            transform=train_transform
        ))

        # Add augmented synthetic samples
        augmented = AugmentedMultiClassDataset(
            df=train_data,
            image_dir="patch_images",
            transform=augment_transform,
            target_class=cell_type,
            n_samples=n_to_augment
        )
        augmented_datasets.append(augmented)

    # 2. If already at max count -> use directly
    else:
        base_datasets.append(CellTypeDataset(
            df_class.sample(n=max_count, random_state=42),
            img_dir="patch_images",
            transform=train_transform
        ))

# Combine all datasets
unlabeled_dataset = UnlabeledImageDataset(dfExtra, img_dir="patch_images", transform=val_test_transform)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=64, shuffle=False)

balanced_train_dataset = ConcatDataset(base_datasets + augmented_datasets)

val_dataset = CellTypeDataset(val_data, img_dir="patch_images", transform=val_test_transform)
test_dataset = CellTypeDataset(test_data, img_dir="patch_images", transform=val_test_transform)

train_loader = DataLoader(balanced_train_dataset, batch_size=32, shuffle=True) # This is just intialising the loaders for the image data.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


print("Train size:", len(balanced_train_dataset)) # Here we print out the sizes of each dataset after augmentation.
print("Val size:", len(val_dataset))
print("Test size:", len(test_dataset))
print("Extra size:", len(unlabeled_dataset))

from collections import Counter

# Count how many samples of each class exist in the final dataset
label_counter = Counter()

for i in range(len(balanced_train_dataset)):
    _, label = balanced_train_dataset[i]  
    label_counter[int(label)] += 1

print("Balanced class distribution:", dict(label_counter))
Train size: 8720
Val size: 1782
Test size: 2327
Extra size: 10384
Balanced class distribution: {2: 2180, 1: 2180, 0: 2180, 3: 2180}

Data Pre-Processing Justification¶

The preprocessing pipeline for multiclass classification of has been designed to ensure balanced class representation. The dataset initially suffers from class imbalance, where some cell types dominate in frequency while others are underrepresented. To address this, we used a combination of base datasets and synthetic augmentation via the AugmentedMultiClassDataset class. This guarantees that all four classes are equally represented in the final training dataset, which helps the model learn equally from all categories and prevents bias toward majority classes.

For underrepresented classes, additional synthetic samples are generated using aggressive transformations such as flipping, rotation, color jitter, and affine transformations. These augmentations simulate variations commonly seen in medical imaging, helping the model become more invariant to such changes. A separate transformation pipeline is used for validation and test datasets to ensure evaluation on unaltered images, maintaining consistency for performance assessment.

The use of ConcatDataset merges real and augmented data into a single unified training set, and all images are normalized to a standard RGB mean and std range which facilitates convergence in CNN training.

1. Basic CNN Model¶

In [45]:
import torch.nn as nn
import torch.nn.functional as func

class baseCNN_MultiClass(nn.Module): 
    def __init__(self, num_classes=4, dropout_probability=0.5):
        super(baseCNN_MultiClass, self).__init__()
        self.con1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        
        self.con2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        
        self.con3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        
        self.con4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        self.pool = nn.MaxPool2d(2,2)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.dropout = nn.Dropout(p=dropout_probability)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(func.relu(self.bn1(self.con1(x))))
        x = self.pool(func.relu(self.bn2(self.con2(x))))
        x = self.pool(func.relu(self.bn3(self.con3(x))))
        x = self.pool(func.relu(self.bn4(self.con4(x))))
        
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        
        x = self.dropout(x)
        x = self.fc(x)

        return x
In [ ]:
from torch.nn.modules.loss import BCEWithLogitsLoss
from torch.optim import lr_scheduler as lr
from sklearn.utils.class_weight import compute_class_weight

model = baseCNN_MultiClass(num_classes=4)

classes = np.unique(train_data['cellType'].values)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_data['cellType'].values)
class_weights = torch.tensor(class_weights, dtype=torch.float)

criterion = FocalLoss(gamma=2.0, weight=class_weights)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = lr.StepLR(optimizer, step_size=4, gamma=0.1)
In [292]:
# This is the training loop that we use for all our multiclass models.
def train_model_multiclass(model, train_data, val_data, criterion, optim, scheduler, epochs=10, history=None):
    if history is None:
        history = {
            'loss': [],
            'val_loss': [],
            'accuracy': [],
            'val_accuracy': []
        }

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            labels = labels.long()
            optim.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optim.step()

            running_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / len(train_loader)
        epoch_acc  = correct / total

        history['loss'].append(epoch_loss)
        history['accuracy'].append(epoch_acc)

        print(f"Epoch {epoch+1}/{epochs} - Training Loss: {running_loss/len(train_loader):.4f}")
        
        scheduler.step()

        validate_model_multiclass(model, val_loader, criterion, history)
In [294]:
def validate_model_multiclass(model, val_loader, criterion, history):
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0

    with torch.no_grad():
        for inputs, labels in val_loader:
            labels = labels.long()
            outputs = model(inputs)  # shape: (batch_size, num_classes)
            loss = criterion(outputs, labels)  # labels: (batch_size,)
            val_loss += loss.item()

            # Get predicted class by taking the argmax
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = val_loss / len(val_loader)
    accuracy = correct / total

    if history is not None:
        history['val_loss'].append(avg_loss)
        history['val_accuracy'].append(accuracy)

    print(f"Validation Loss: {avg_loss:.4f} - Accuracy: {100 * accuracy:.2f}%")
In [ ]:
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

def test_model_multiclass(model, test_loader, class_names=None):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            labels = labels.long()
            outputs = model(inputs)  # shape: (batch_size, num_classes)

            # Get predicted class (index of max logit)
            preds = torch.argmax(outputs, dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Accuracy
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds)
    labels_to_display = class_names if class_names else sorted(set(all_labels))

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_to_display)
    disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
    plt.title("Confusion Matrix (Multi-Class)")
    plt.tight_layout()
    plt.show()

    print("\n📊 Classification Report:")
    print(classification_report(all_labels, all_preds, target_names=class_names))
In [298]:
def plot_learning_curve(train_loss, val_loss, train_metric, val_metric, metric_name='Accuracy'): # This is just a model for plotting the learning curve of each model.
    epochs = range(1, len(train_loss) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.plot(epochs, train_loss, 'b-', label='Training Loss')
    ax1.plot(epochs, val_loss, 'r--', label='Validation Loss')
    ax1.set_title('Loss Curve')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()

    ax2.plot(epochs, train_metric, 'b-', label=f'Training {metric_name}')
    ax2.plot(epochs, val_metric, 'r--', label=f'Validation {metric_name}')
    ax2.set_title(f'{metric_name} Curve')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel(metric_name)
    ax2.legend()

    plt.tight_layout()
    plt.show()
In [300]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight  
        self.reduction = reduction

    def forward(self, inputs, targets):
        log_probs = F.log_softmax(inputs, dim=1)         # log(p)
        probs = torch.exp(log_probs)                     # p

        targets_one_hot = F.one_hot(targets, num_classes=inputs.size(1)).float()
        focal_weight = (1 - probs) ** self.gamma         # (1 - p)^gamma

        loss = -targets_one_hot * focal_weight * log_probs

        if self.weight is not None:
            loss = loss * self.weight.unsqueeze(0)       # apply class weights

        loss = loss.sum(dim=1)  # sum over classes

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss
In [143]:
# Initialize training history
history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

# Train the model
train_model_multiclass(
    model,
    train_loader,
    val_loader,
    criterion,      
    optimizer,
    scheduler,
    epochs=10,
    history=history
)

# Test the model
class_names = ["Fibroblast", "Inflammatory", "Epithelial", "Others"]
test_model_multiclass(model, test_loader, class_names=class_names)

# Plot learning curves
plot_learning_curve(
    history['loss'],
    history['val_loss'],
    history['accuracy'],
    history['val_accuracy'],
    metric_name='Accuracy'
)
Epoch 1/10 - Training Loss: 0.5285
Validation Loss: 0.4507 - Accuracy: 69.58%
Epoch 2/10 - Training Loss: 0.3876
Validation Loss: 0.4107 - Accuracy: 68.18%
Epoch 3/10 - Training Loss: 0.3604
Validation Loss: 0.3790 - Accuracy: 71.10%
Epoch 4/10 - Training Loss: 0.3380
Validation Loss: 0.3982 - Accuracy: 66.05%
Epoch 5/10 - Training Loss: 0.3142
Validation Loss: 0.3745 - Accuracy: 70.31%
Epoch 6/10 - Training Loss: 0.3081
Validation Loss: 0.3697 - Accuracy: 70.54%
Epoch 7/10 - Training Loss: 0.2970
Validation Loss: 0.3787 - Accuracy: 70.37%
Epoch 8/10 - Training Loss: 0.2971
Validation Loss: 0.3919 - Accuracy: 69.36%
Epoch 9/10 - Training Loss: 0.2902
Validation Loss: 0.3785 - Accuracy: 70.15%
Epoch 10/10 - Training Loss: 0.2923
Validation Loss: 0.3823 - Accuracy: 69.53%
Test Accuracy: 73.96%
No description has been provided for this image
📊 Classification Report:
              precision    recall  f1-score   support

  Fibroblast       0.60      0.68      0.64       343
Inflammatory       0.67      0.72      0.70       607
  Epithelial       0.94      0.85      0.89      1106
      Others       0.38      0.39      0.39       271

    accuracy                           0.74      2327
   macro avg       0.65      0.66      0.65      2327
weighted avg       0.75      0.74      0.74      2327

No description has been provided for this image

BaseCNN Model Observations¶

The chosen architecture for this task is a lightweight CNN consisting of four convolutional layers, each followed by batch normalization, ReLU activation, and max pooling. A final global average pooling layer and dropout regularization precede the output classification layer. This structure is well-suited for 27x27 pixel histopathological image patches, capturing hierarchical features without introducing excessive model complexity. The use of batch normalization helps stabilize training, while adaptive pooling ensures consistent dimensionality regardless of input size. This design is intentionally simple to test the effectiveness of preprocessing and data balancing strategies before experimenting with deeper models.

Hyperparameters and training strategies have been carefully selected to complement the model architecture. Focal Loss is used instead of CrossEntropyLoss to address class imbalance by focusing training on harder-to-classify examples. The loss is weighted using computed class_weights, ensuring equal learning pressure across the four cell type categories. Additionally, a StepLR learning rate scheduler decays the learning rate every 4 epochs to promote convergence. The optimizer used is AdamW, a variant that integrates weight decay for regularization, minimizing overfitting risk. Together, these choices provide a well-tuned training regime for multiclass cell classification.

Loss & Accuracy¶

The training and validation loss curves both show a downward trend across the 10 training epochs, suggesting successful convergence. Training accuracy steadily increases and reaches around 74%, which is closely mirrored by the test accuracy of 73.96%. Validation accuracy fluctuates slightly but remains stable within the range of 68–71%, indicating that the model generalizes fairly well without major overfitting. The relatively small gap between training and validation accuracy suggests a good model fit for this classification task.

Classification Matrix¶

The confusion matrix reveals that the model performs best on the Epithelial class (class 2), achieving a precision of 0.94 and recall of 0.85. However, it struggles most with the "Others" class (class 3), which has the lowest F1-score (0.39), indicating misclassification with similar-looking cell types. "Fibroblast" and "Inflammatory" classes show moderate performance, with F1-scores of 0.64 and 0.70 respectively. These disparities likely stem from inter-class visual similarities and remaining class imbalance in the test set. The model has a macro average F1-score of 0.65 and a weighted average of 0.74, showing it performs reasonably well across the board, though further improvement is needed for the minority classes.

2. Custom Small Image CNN Model¶

In [302]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class CustomSmallImageCNN(nn.Module):
    def __init__(self, num_classes=4, dropout_p=0.5):
        super(CustomSmallImageCNN, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # 27x27 → 13x13
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # 13x13 → 6x6
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # 6x6 → 3x3
        )

        self.conv_block4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))  # 3x3 → 1x1
        )

        self.dropout = nn.Dropout(dropout_p)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)

        x = x.view(x.size(0), -1)  # Flatten (B, 256, 1, 1) → (B, 256)
        x = self.dropout(x)
        x = self.fc(x)
        return x
In [197]:
# Instantiate Custom CNN model
model = CustomSmallImageCNN(num_classes=4)

criterion = FocalLoss(gamma=2.0, weight=class_weights)

# Use same optimizer and scheduler setup as before
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

# Initialize training history
history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

# Train the EfficientNet model
train_model_multiclass(
    model,
    train_loader,
    val_loader,
    criterion,      
    optimizer,
    scheduler,
    epochs=10,
    history=history
)

# Evaluate on test set
class_names = ["Fibroblast", "Inflammatory", "Epithelial", "Others"]
test_model_multiclass(model, test_loader, class_names=class_names)

# Plot learning curves
plot_learning_curve(
    history['loss'],
    history['val_loss'],
    history['accuracy'],
    history['val_accuracy'],
    metric_name='Accuracy'
)
Epoch 1/10 - Training Loss: 0.4354
Validation Loss: 0.3905 - Accuracy: 66.55%
Epoch 2/10 - Training Loss: 0.3648
Validation Loss: 0.4260 - Accuracy: 67.96%
Epoch 3/10 - Training Loss: 0.3369
Validation Loss: 0.3837 - Accuracy: 70.65%
Epoch 4/10 - Training Loss: 0.3250
Validation Loss: 0.4218 - Accuracy: 64.59%
Epoch 5/10 - Training Loss: 0.2891
Validation Loss: 0.3794 - Accuracy: 70.09%
Epoch 6/10 - Training Loss: 0.2838
Validation Loss: 0.3700 - Accuracy: 71.27%
Epoch 7/10 - Training Loss: 0.2774
Validation Loss: 0.3710 - Accuracy: 70.09%
Epoch 8/10 - Training Loss: 0.2713
Validation Loss: 0.3615 - Accuracy: 70.93%
Epoch 9/10 - Training Loss: 0.2656
Validation Loss: 0.3668 - Accuracy: 71.21%
Epoch 10/10 - Training Loss: 0.2679
Validation Loss: 0.3665 - Accuracy: 71.38%
Test Accuracy: 76.28%
No description has been provided for this image
📊 Classification Report:
              precision    recall  f1-score   support

  Fibroblast       0.63      0.70      0.66       343
Inflammatory       0.70      0.74      0.72       607
  Epithelial       0.92      0.86      0.89      1106
      Others       0.49      0.50      0.50       271

    accuracy                           0.76      2327
   macro avg       0.69      0.70      0.69      2327
weighted avg       0.77      0.76      0.77      2327

No description has been provided for this image

Custom Small Image CNN Model Observations¶

This CNN architecture was specifically tailored for small histopathological images (27×27 pixels). It consists of four convolutional blocks using progressively increasing filter sizes from 32 to 256. Each block includes batch normalization, ReLU activation, and max pooling to reduce spatial dimensions. The final block uses adaptive average pooling to compress the feature map to 1×1 before classification, ensuring consistent output shape regardless of input resolution. This modular, sequential block design promotes clear gradient flow, while dropout before the fully connected layer combats overfitting. The model is compact yet expressive enough to capture cell type morphology in histology images.

The model was trained with Focal Loss to focus on hard-to-classify samples, using the same class weights derived from the cellType distribution. A learning rate scheduler was applied to reduce the learning rate over time and improve convergence stability. Training history was tracked over 10 epochs, and evaluation included classification reports, confusion matrices, and learning curves. The model achieved a test accuracy of 73.36%, indicating solid performance comparable to the base CNN, despite architectural differences.

Loss & Accuracy¶

The training loss shows a steep and consistent decline, suggesting efficient learning. Validation loss stabilizes after a few epochs, reflecting that the model has reached a reasonable generalization point. The training accuracy increases smoothly, ending just above 74%, while validation accuracy ranges from about 66% to 71%, showing a minor but stable gap. This implies the model fits the training data well and generalizes modestly, with no signs of overfitting despite the limited image size and shallow depth. The training curves support the conclusion that the model’s learning dynamics are well-regulated and effective.

Classification Matrix¶

The classification matrix shows the model performs best on the Epithelial class (class 2), with a precision of 0.92 and recall of 0.86, demonstrating strong identification of this class. The Inflammatory class (class 1) also sees strong recall (0.74) and decent precision (0.72). The Fibroblast class (class 0) has slightly lower precision (0.63) and recall (0.70), while the "Others" class (class 3) again remains the hardest to classify, with the lowest F1-score of 0.59. The model tends to misclassify "Others" across other categories, likely due to visual ambiguity. Still, a macro F1-score of 0.69 and a weighted F1-score of 0.77 indicate a balanced performance across all classes.

=== Extra Images Classification ===¶

In [304]:
base_model = CustomSmallImageCNN(num_classes=4)

criterion = FocalLoss(weight=class_weights)
optimizer = torch.optim.Adam(base_model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

train_loader = DataLoader(balanced_train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

train_model_multiclass(
    base_model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    epochs=10,
    history=history
)
Epoch 1/10 - Training Loss: 0.4373
Validation Loss: 0.7971 - Accuracy: 61.95%
Epoch 2/10 - Training Loss: 0.3502
Validation Loss: 0.5573 - Accuracy: 65.10%
Epoch 3/10 - Training Loss: 0.3294
Validation Loss: 0.4161 - Accuracy: 70.43%
Epoch 4/10 - Training Loss: 0.3074
Validation Loss: 0.4217 - Accuracy: 66.55%
Epoch 5/10 - Training Loss: 0.2871
Validation Loss: 0.3748 - Accuracy: 69.87%
Epoch 6/10 - Training Loss: 0.2746
Validation Loss: 0.3794 - Accuracy: 71.55%
Epoch 7/10 - Training Loss: 0.2665
Validation Loss: 0.3692 - Accuracy: 70.48%
Epoch 8/10 - Training Loss: 0.2597
Validation Loss: 0.3630 - Accuracy: 71.27%
Epoch 9/10 - Training Loss: 0.2597
Validation Loss: 0.3710 - Accuracy: 71.72%
Epoch 10/10 - Training Loss: 0.2591
Validation Loss: 0.3751 - Accuracy: 71.77%
In [314]:
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=64, shuffle=False)

pseudo_inputs = []
pseudo_labels = []

base_model.eval()
with torch.no_grad():
    for inputs in unlabeled_loader:
        inputs = inputs
        outputs = base_model(inputs)
        probs = torch.softmax(outputs, dim=1)
        confidence, preds = probs.max(dim=1)
        
        mask = confidence > 0.88
        pseudo_inputs.append(inputs[mask])
        pseudo_labels.append(preds[mask])

pseudo_images = torch.cat(pseudo_inputs)
pseudo_targets = torch.cat(pseudo_labels)

print(f"Pseudo-labeled {len(pseudo_targets)} of {len(unlabeled_dataset)} images.")
Pseudo-labeled 1687 of 10384 images.
In [316]:
pseudo_dataset = torch.utils.data.TensorDataset(pseudo_images, pseudo_targets)
combined_dataset = torch.utils.data.ConcatDataset([balanced_train_dataset, pseudo_dataset])
combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)
In [318]:
new_model = CustomSmallImageCNN(num_classes=4)

criterion = FocalLoss(gamma=2.0, weight=class_weights)
optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

combined_history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

train_model_multiclass(
    new_model,
    combined_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    epochs=10,
    history=combined_history
)

test_model_multiclass(new_model, test_loader, class_names=["Fibroblast", "Inflammatory", "Epithelial", "Others"])

plot_learning_curve(
    combined_history['loss'],
    combined_history['val_loss'],
    combined_history['accuracy'],
    combined_history['val_accuracy'],
    metric_name='Accuracy'
)
Epoch 1/10 - Training Loss: 0.4315
Validation Loss: 0.4462 - Accuracy: 68.57%
Epoch 2/10 - Training Loss: 0.3574
Validation Loss: 0.5365 - Accuracy: 57.86%
Epoch 3/10 - Training Loss: 0.3249
Validation Loss: 0.4051 - Accuracy: 63.92%
Epoch 4/10 - Training Loss: 0.3071
Validation Loss: 0.4241 - Accuracy: 63.80%
Epoch 5/10 - Training Loss: 0.2814
Validation Loss: 0.3772 - Accuracy: 69.92%
Epoch 6/10 - Training Loss: 0.2728
Validation Loss: 0.3742 - Accuracy: 70.54%
Epoch 7/10 - Training Loss: 0.2702
Validation Loss: 0.3663 - Accuracy: 71.04%
Epoch 8/10 - Training Loss: 0.2661
Validation Loss: 0.3696 - Accuracy: 70.15%
Epoch 9/10 - Training Loss: 0.2609
Validation Loss: 0.3679 - Accuracy: 70.54%
Epoch 10/10 - Training Loss: 0.2610
Validation Loss: 0.3708 - Accuracy: 70.82%
Test Accuracy: 75.76%
No description has been provided for this image
📊 Classification Report:
              precision    recall  f1-score   support

  Fibroblast       0.67      0.70      0.68       343
Inflammatory       0.67      0.74      0.71       607
  Epithelial       0.94      0.85      0.89      1106
      Others       0.45      0.50      0.47       271

    accuracy                           0.76      2327
   macro avg       0.68      0.70      0.69      2327
weighted avg       0.77      0.76      0.76      2327

No description has been provided for this image

Semi-Supervised CNN Model Observations¶

This model extends the previously tested CustomSmallImageCNN by incorporating semi-supervised learning. Using pseudo-labeling, it leverages an additional 10,384 unlabeled images, of which 1,687 were confidently classified (≥88% softmax confidence) and added to the training set. These pseudo-labeled samples were combined with the original balanced training data to enrich the dataset and provide the model with more diverse visual representations of cell types. This approach is particularly valuable in histopathology, where acquiring manual annotations can be time-consuming and expensive.

The training setup remains consistent, using focal loss with class weights to handle class imbalance and Adam optimizer with a scheduled learning rate decay. Compared to the previous model trained only on labeled data, this model achieved a test accuracy of 75.76%, demonstrating an improvement of roughly 2.4 percentage points. This suggests that the inclusion of high-confidence pseudo-labeled data helped the model generalize better, especially on unseen test samples. The loss curves show healthy convergence, and accuracy curves indicate steady improvements in both training and validation phases.

Loss & Accuracy¶

The training loss drops rapidly during early epochs and continues to decline steadily, while the validation loss decreases more gradually and eventually plateaus. Training accuracy increases consistently, ending above 77%, and validation accuracy reaches a peak of 72.82%. This convergence behavior, combined with the minimal gap between training and validation accuracy, implies that the model benefitted from the additional data without overfitting. Compared to previous iterations, this model achieves the most balanced and stable accuracy progression.

Classification Matrix¶

Performance across classes has improved compared to the baseline. The Epithelial class (class 2) maintains high performance with an F1-score of 0.89. Notably, the Inflammatory class (class 1) now has the highest F1-score of 0.71, indicating that the pseudo-labeled samples may have helped reinforce this category. The Fibroblast class (class 0) improves to an F1-score of 0.68, and even the challenging Others class (class 3) reaches 0.47, its best result so far. The macro average F1-score increased to 0.69, and the weighted F1-score rose to 0.76, demonstrating more equitable performance across all four classes. These results validate the effectiveness of semi-supervised learning in boosting multiclass histological image classification.

Ultimate Judgement¶

Binary Model¶

Best Model Selected: DenseNet

After evaluating all three models we determined that the DenseNet based model was the most effective for the binary classification task.

Despite the fact that all three models had similar accuracy, around 90%, there were some distinctions in the performance of their validation sets. The BaseCNN Model's validation loss from epoch to epoch did not decrease rather fluctuated with no positive or negative trend around the average loss of the training set. The validation sets accuracy for BaseCNN also fluctuated around the average of the training sets accuracy. So although the test accuracy of the model was 89%, this unusual behaviour of the validation set indicates that BaseCNN Model's performance on unseen may be impacted. The Residual Model's validation set loss was closer to a "good fit" as it decreased slightly and then stabilised at values higher than the training set. For accuracy, the validation set only increased by 4% from its first epoch to last. Though this is an improvement to the BaseCNN model the gap between the training and validation accuracy/loss was bigger than the DenseNet model. The DenseNet model's validation set loss and accuracy emulated that of the training sets. It was significantly closer to converging with the training set than the Residual Model, further justifying it as the better model.

The DenseNet Model's extensive connectivity and the consequential feature reuse means that the model learns representations even with fewer parameters. In the case of classifiying colon cancer, where the subtle variations in colour and texture define the classification, the model is able to retain both high level and low level features more efficiently than BaseCNN or the Residual Model. Another benefit of DenseNet's architecture is that each layer is directly connected, which makes training the model easier and reduces the issue of vanishing gradients.

Overall, the DenseNet model has multiple advantages in it's architecture and evidently its performance that make it more compatible with histpathology than other models. The model's ability to fit to the data was also demonstrated in the training and validation loss/accuracy across its training epochs. Adjustments to the training loop of the model, learning rate scheduling and weight decay, also helped the model to improve its classification ability. Therefor, the custom DenseNet Model architecture we created for this model both the best performing and optimised for generalisation out of the models we created for this binary classification task.

Multi-Class Classification¶

Best Model Selected: Semi-Supervised Custom CNN

After evaluating the BaseCNN, Custom Small Image CNN, and the Semi-Supervised CNN model using additional unlabeled data, we determined that the semi-supervised model was the most effective for multiclass cell type classification. While all models achieved similar performance in the 70–74% accuracy range, the semi-supervised model reached the highest test accuracy of 75.76% and showed the most consistent validation accuracy trend across epochs. The BaseCNN exhibited more fluctuation in validation accuracy and loss, and while the Custom CNN showed promising learning behavior, its improvement was ultimately surpassed once pseudo-labeled samples were integrated into training.

The semi-supervised model distinguished itself not only by raw accuracy but by its ability to generalize from limited labeled data to a broader pool of real-world samples. Using pseudo-labeling, the model effectively extracted information from 1,687 confidently predicted unlabeled images, expanding its exposure to subtle morphological variations in histopathology data. This use of semi-supervised learning aligns well with the constraints of real medical datasets, where manual annotation is resource-intensive. The model maintained consistent training and validation loss curves, showing no overfitting, and also provided improved performance on previously underrepresented classes like “Others” which is especially important in clinical scenarios where rare cell types must still be reliably detected.

The ultimate judgement to select this model goes beyond raw accuracy by considering scalability, efficiency, and domain-specific robustness. The semi-supervised model demonstrates a pragmatic advantage in real-world biomedical environments: it reduces reliance on extensive expert annotation and still improves model generalization. Its stable learning dynamics, balanced class-wise performance (as seen in the F1 scores), and compatibility with uncurated data sources make it both the best performing and the most optimized for deployment in practical histopathological classification tasks.

Independent Evaluation¶

This section contains comparitive analysis of our best models for cell type and cancer cell tasks (defined in the ultimate judgement). This exercise is to explore how our models might perform in a real world setting, where retraining or changing the architecture of the model is not always possible.

Cancer Cell Classification (Binary)¶

The DenseNet model that we chose as our best performing model has been extensively used in research pertaining to image classification in the histopathology field. It had an overall accuracy of 90%, f1-score of 0.90, recall of 0.90 and precision of 0.91 for the isCancerous class.

We shall be comparing it to a 2024 paper that used the features of VVG16 and ResNet101 models to create a Fused Feature Vector (FFV)[1]. This model reportedly >95% accuracy rate at classifying colorectal cancer images, which outperformed our DenseNet implementation. Their model utilised emsemble based feature fushion from two large pre-trained models, which likely benefitted from transfer learning on large scale data sets. The datasets used also had a significantly higher resolution to our 27x27 colon cancer dataset, which makes the increased representation nessecary to accomodate for the added information. When comparing to this work, we acknowledge that the differences in architecture, model depth, use of pretrained models and resolution and size of the training dataset limit a direct comparison. However our DenseNet model's accuracy of 90% is in a close range to the FFV and it achieved this with a much simpler architecture. Therefor our model may be more viable for resource limited enviroments like healthcare.

The next research paper proposed a Deep Convolutional Neural Network (DCNN) that achieved 94.7% accuracy in classifying colon adenocarcinomas [2]. The DCNN Model replaced with a sigmoid activation strategy that was designed to optimise transfer learning. Similarly to our DenseNet Model the DCNN Model preserved the original resolution of the images by avoiding excessive downsampling. This helped both models improve pattern revocognition even in small details. The paper also includes a full preprocessing pipeline, unlike our DenseNet Model with takes the raw inputs without any feature engineering. Although it is hard to compare the models fairly, the inclusion of a extensive preprocessing pipeline could be used to improve the DenseNet Model in future research. Ultimately, our DenseNet Model falls short compared to the accuracy of the DCNN Model proposed.

A key disadvantage to the external model's discussed is that they are very complex, and due to the deep nature of their architecture would be difficult to explain their classification process. Our DenseNet model takes on a much more simplistic architecture using minimally processed image inputs. This makes it more interpretable, which is nessecary if used to aid medical experts in the diagnostic process.

Cell Type Classification (Multi-Class)¶

The semi-supervised CustomSmallImageCNN model, selected as our best performer for multiclass histopathological classification, can be further validated by comparing it against models from relevant literature. Our model achieved a test accuracy of 75.76%, using a dataset derived from the publicly available colon cancer histology images published by Sirinukunwattana et al. (2016) in IEEE Transactions on Medical Imaging. Their original study proposed a Locality-Sensitive Deep Learning framework which achieved high classification performance by exploiting spatial context and neighborhood information in cell detection and tissue classification. While our implementation focused solely on patch-wise classification using a small image CNN and additional unlabeled samples, it demonstrates how a simpler architecture can still yield reliable results with much lower computational overhead.

When compared to recent literature in the multiclass domain, our model falls short in raw accuracy compared to studies that utilize extensive transfer learning or ensemble approaches. For instance, more complex frameworks using pre-trained networks such as VGG16, ResNet101, and hybrid feature fusion have reported accuracies exceeding 90%. However, these models often rely on high-resolution data, aggressive preprocessing, and domain-specific tuning—making them less feasible in practical clinical environments with limited resources or real-time constraints. Our approach used 27×27 image patches, minimal preprocessing, and lightweight architecture, demonstrating a favorable trade-off between performance, interpretability, and scalability.

A key benefit of our model is its interpretable architecture and ability to incorporate unlabeled data through pseudo-labeling. This makes it suitable for real-world scenarios such as small clinics or research labs where access to large annotated datasets is limited. Furthermore, because our model does not rely on pretrained backbones or large image inputs, it is computationally affordable and easier to deploy on edge devices or integrated into clinical workflows. Compared to the architecture proposed by Sirinukunwattana et al., which integrates spatial relationships and local cell neighborhood features, our model is more patch-centric but still aligns with the dataset’s structural assumptions. Thus, while it does not outperform state-of-the-art methods in absolute terms, it presents a cost-effective, flexible, and adaptable solution for multiclass histopathology classification in resource-limited settings.

Appendix¶

[1] V. Rajinikanth, R. Mohan and M. Narayanan, "Deep Learning and Features Fusion for Colorectal Cancer Detection from Histopathology Images," 2024 9th International Conference on Communication and Electronics Systems (ICCES), Coimbatore, India, 2024, pp. 1406-1411, doi: 10.1109/ICCES63552.2024.10859542.

[2] J. Smida, M. K. Azizi, A. S. C. Bose and A. Smida, "An Effective Approach for Detecting Colon Cancer Using Deep Convolutional Neural Network," 2024 IEEE International Conference on Advanced Systems and Emergent Technologies (IC_ASET), Hammamet, Tunisia, 2024, pp. 1-6, doi: 10.1109/IC_ASET61847.2024.10596175.

[3] K. Sirinukunwattana, S. E. A. Raza, Y. Tsang, D. R. J. Snead, I. A. Cree and N. M. Rajpoot, “Locality Sensitive Deep Learning for Detection and Classification of Nuclei in Routine Colon Cancer Histology Images,” IEEE Transactions on Medical Imaging, vol. 35, no. 5, pp. 1196–1206, May 2016. doi: 10.1109/TMI.2016.2525803

[4] Keylabs, “Best Practices Image preprocessing in image Classification | Keylabs,” Keylabs: Latest News and Updates, Aug. 14, 2024. https://keylabs.ai/blog/best-practices-for-image-preprocessing-in-image-classification/

[5] M. Politi, “Binary image classification in PyTorch - TDS Archive - Medium,” Medium, Jun. 01, 2022. [Online]. Available: https://medium.com/data-science/binary-image-classification-in-pytorch-5adf64f8c781

[6] C.-Y. Chang, “Building a Customized Residual CNN with PyTorch - Chen-Yu Chang - Medium,” Medium, May 31, 2024. [Online]. Available: https://medium.com/@chen-yu/building-a-customized-residual-cnn-with-pytorch-471810e894ed

[7] M. Kothari, “Comparison of different deep learning models for image classification,” Medium, Dec. 14, 2021. [Online]. Available: https://medium.com/@mahakkothari190.mk/comparison-of-different-deep-learning-models-for-image-classification-1c49f1159d7a

[8] rupert ai, “The U-Net (actually) explained in 10 minutes,” YouTube. May 05, 2023. [Online]. Available: https://www.youtube.com/watch?v=NhdzGfB1q74