31 changes: 31 additions & 0 deletions recognition/UNet3D_47473289/README.md
# Improved 3D UNet for Prostate MRI Segmentation

## Overview

This is an implementation of an improved 3D UNet that segments the (downsampled) Prostate 3D data set, with the goal of all labels achieving a minimum Dice similarity coefficient (DSC) of 0.7 on the test set. The data is augmented using appropriate PyTorch transforms and is loaded from the NIfTI file format (".nii.gz").


## Improved UNet3D Model

The Improved 3D UNet is a convolutional neural network (CNN) for 3D image segmentation.
The model is based on the Keras BraTS UNet model, using an encoder and a decoder with a bottleneck in between. The encoder applies convolutions to the input volumes, reducing their spatial dimensions to identify patterns in the images. The decoder applies up-convolutions to the downsampled feature maps, concatenating at each step with the corresponding encoder output from the same level. The bottleneck is the connection between the encoder and the decoder.
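
As a sketch of how the model defined in `modules.py` is instantiated (six output channels, one per label, matching the rest of this submission):

```python
from modules import ImprovedUNet3D

# One input channel (MRI intensity), six output channels (one per label).
model = ImprovedUNet3D(in_channels=1, out_channels=6)
```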

The dataset is split into training, validation and testing subsets, a 70% / 20% / 10% partition of the total data loaded. A sketch of producing such a split appears below.
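
A minimal sketch using `torch.utils.data.random_split` (the exact split call is not shown in this repository, so the snippet is illustrative):

```python
from torch.utils.data import DataLoader, random_split
from dataset import MyCustomDataset

dataset = MyCustomDataset()
# 70% training, 20% validation, 10% testing.
train_set, val_set, test_set = random_split(dataset, [0.7, 0.2, 0.1])
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
```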

Training minimises the DSC loss on the training data, which in turn maximises the DSC measured in evaluation mode on the validation set. Appropriate transformations are applied to the data to speed up training.
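
A minimal sketch of a single training step under this setup, assuming the `ImprovedUNet3D` and `DSC` modules from `modules.py` (the optimiser and learning rate are illustrative assumptions, not taken from the training script):

```python
import torch
from modules import ImprovedUNet3D, DSC

model = ImprovedUNet3D(in_channels=1, out_channels=6)
criterion = DSC()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_step(images, labels):
    model.train()
    optimizer.zero_grad()
    outputs = torch.softmax(model(images), dim=1)  # per-class probabilities
    loss = criterion(outputs, labels)              # 1 - mean DSC
    loss.backward()
    optimizer.step()
    return loss.item()
```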

## Dependencies

- Python 3.12.3
- PyTorch 2.4.0
- NumPy 1.26.4
- tqdm 4.66.4
- NiBabel 5.3.1
- Matplotlib 3.8.4

## Results

Average DSC across classes = 0.6122
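
For reference, a sketch of how an average DSC across classes can be computed from integer class volumes (this helper is illustrative; it mirrors the smoothing used by the `DSC` loss in `modules.py`):

```python
import torch

def per_class_dice(pred, target, num_classes=6, eps=1.0):
    """pred, target: integer class volumes of identical shape."""
    scores = []
    for c in range(num_classes):
        p, t = (pred == c).float(), (target == c).float()
        inter = (p * t).sum()
        scores.append(((2 * inter + eps) / (p.sum() + t.sum() + eps)).item())
    return scores  # average these for the DSC across classes
```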
136 changes: 136 additions & 0 deletions recognition/UNet3D_47473289/dataset.py
# Data loader for loading and preprocessing data.
from torch.utils.data import DataLoader, TensorDataset, Dataset
import torch
import nibabel as nib
import numpy as np
import tqdm
import torch.nn as nn
import glob


def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray:
    # One binary channel per unique label value present in the array.
    channels = np.unique(arr)
    res = np.zeros(arr.shape + (len(channels),), dtype=dtype)
    for c in channels:
        c = int(c)
        res[..., c:c + 1][arr == c] = 1

    return res
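
# A small illustration of to_channels (values chosen for this example only):
#   to_channels(np.array([[0, 1], [2, 1]])) has shape (2, 2, 3); its channel 1
#   is [[0, 1], [0, 1]], marking where label 1 occurs.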


def load_data_3D(imageNames, normImage=False, categorical=False, dtype=np.float32,
                 getAffines=False, orient=False, early_stop=False):
    '''
    Load medical image data from names, cases list provided into a list for each.

    This function pre-allocates 5D arrays for conv3d to avoid excessive memory
    usage.

    normImage: bool (normalise the image 0.0-1.0)
    orient: Apply orientation and resample image? Good for images with large
        slice thickness or anisotropic resolution
    dtype: Type of the data. If dtype=np.uint8, it is assumed that the data is
        labels
    early_stop: Stop loading prematurely? Leaves arrays mostly empty, for quick
        loading and testing scripts.
    '''
    affines = []

    # ~ interp = 'continuous'
    interp = 'linear'
    if dtype == np.uint8:  # assume labels
        interp = 'nearest'

    # get fixed size
    num = len(imageNames)

    # ~ testResultName = "oriented.nii.gz"
    # ~ niftiImage.to_filename(testResultName)
    first_case = nib.load(imageNames[0]).get_fdata(caching='unchanged')

    if len(first_case.shape) == 4:
        first_case = first_case[:, :, :, 0]  # sometimes extra dims, remove
    if categorical:
        first_case = to_channels(first_case, dtype=dtype)
        rows, cols, depth, channels = first_case.shape
        images = np.zeros((num, rows, cols, depth, channels), dtype=dtype)
    else:
        rows, cols, depth = first_case.shape
        images = np.zeros((num, rows, cols, depth), dtype=dtype)

    for i, inName in enumerate(tqdm.tqdm(imageNames)):
        niftiImage = nib.load(inName)
        if len(first_case.shape) == 3:
            first_case = np.expand_dims(first_case, axis=0)

        inImage = niftiImage.get_fdata(caching='unchanged')  # read from disk only
        affine = niftiImage.affine

        if len(inImage.shape) == 4:
            inImage = inImage[:, :, :, 0]  # sometimes extra dims in HipMRI_study data

        inImage = inImage[:, :, :depth]  # clip slices
        inImage = inImage.astype(dtype)

        if normImage:
            # ~ inImage = inImage / np.linalg.norm(inImage)
            # ~ inImage = 255. * inImage / inImage.max()
            inImage = (inImage - inImage.mean()) / inImage.std()
        if categorical:
            inImage = to_channels(inImage, dtype=dtype)
            # ~ images[i, :, :, :, :] = inImage
            images[i, :inImage.shape[0], :inImage.shape[1], :inImage.shape[2],
                   :inImage.shape[3]] = inImage  # with pad
        else:
            # ~ images[i, :, :, :] = inImage
            images[i, :inImage.shape[0], :inImage.shape[1],
                   :inImage.shape[2]] = inImage  # with pad

        affines.append(affine)
        if i > 20 and early_stop:
            print("Stopped early")
            break

    if getAffines:
        return images, affines
    else:
        return images
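
# Example usage (a sketch; the file name below is illustrative):
#   imgs = load_data_3D(["case_004_week_0.nii.gz"], normImage=True)
#   imgs.shape  # -> (1, rows, cols, depth)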

class MyCustomDataset(Dataset):
    def __init__(self):
        # Collect all NIfTI image/label paths (the commented lines are local alternatives).
        # self.image_paths = glob.glob(f'{"/Users/charl/Documents/3710Report/PatternAnalysis-2024/recognition/semantic_MRs_anon"}/**/*.nii.gz', recursive=True)
        # self.label_paths = glob.glob(f'{"/Users/charl/Documents/3710Report/PatternAnalysis-2024/recognition/semantic_labels_only"}/**/*.nii.gz', recursive=True)
        self.label_paths = glob.glob('/home/groups/comp3710/HipMRI_Study_open/semantic_labels_only/**/*.nii.gz', recursive=True)
        self.image_paths = glob.glob('/home/groups/comp3710/HipMRI_Study_open/semantic_MRs/**/*.nii.gz', recursive=True)

        # Upsample both image and label volumes to a fixed 128^3 grid.
        self.up = torch.nn.Upsample(size=(128, 128, 128))

        # Map raw label values to class indices (identity for the six labels).
        self.classes = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        image = load_data_3D([self.image_paths[i]], early_stop=True)
        label = load_data_3D([self.label_paths[i]], early_stop=True)
        image = torch.tensor(image).float().unsqueeze(1)
        label = torch.tensor([self.classes.get(cla.item(), 0)
                              for cla in label.flatten()]).reshape(label.shape)
        label = nn.functional.one_hot(label.squeeze(0), num_classes=6).permute(3, 1, 0, 2).float()
        label = label.unsqueeze(0)
        image = self.up(image).squeeze(0)
        label = self.up(label).squeeze(0)

        return image, label



if __name__ == "__main__":
    # Quick smoke test of the dataset pipeline.
    dataset = MyCustomDataset()
    print(dataset[10])
    print(dataset[10][1].shape)
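    # A sketch of how a DataLoader would consume this dataset
    # (the batch size is illustrative, not taken from the training script):
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # (1, 1, 128, 128, 128), (1, 6, 128, 128, 128)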
118 changes: 118 additions & 0 deletions recognition/UNet3D_47473289/modules.py
# Source code of the components of the UNet3D Model.

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        # Two 3x3x3 convolutions, each followed by batch norm and ReLU.
        self.block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)

class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpConv, self).__init__()
        # Transposed convolution doubles each spatial dimension.
        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        return self.up(x)

class ImprovedUNet3D(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(ImprovedUNet3D, self).__init__()
        # Contracting path
        self.conv1 = ConvBlock(in_channels, 64)
        self.pool1 = nn.MaxPool3d(2)

        self.conv2 = ConvBlock(64, 128)
        self.pool2 = nn.MaxPool3d(2)

        self.conv3 = ConvBlock(128, 256)
        self.pool3 = nn.MaxPool3d(2)

        self.conv4 = ConvBlock(256, 512)
        self.pool4 = nn.MaxPool3d(2)

        self.bottleneck = ConvBlock(512, 1024)

        # Expansive path (input channels double after skip concatenation)
        self.upconv4 = UpConv(1024, 512)
        self.conv_up4 = ConvBlock(1024, 512)

        self.upconv3 = UpConv(512, 256)
        self.conv_up3 = ConvBlock(512, 256)

        self.upconv2 = UpConv(256, 128)
        self.conv_up2 = ConvBlock(256, 128)

        self.upconv1 = UpConv(128, 64)
        self.conv_up1 = ConvBlock(128, 64)

        self.final_conv = nn.Conv3d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)

        x2 = self.conv2(p1)
        p2 = self.pool2(x2)

        x3 = self.conv3(p2)
        p3 = self.pool3(x3)

        x4 = self.conv4(p3)
        p4 = self.pool4(x4)

        bn = self.bottleneck(p4)

        # Decoder: upsample, concatenate the matching encoder features, convolve.
        u4 = self.upconv4(bn)
        u4 = torch.cat([u4, x4], dim=1)
        u4 = self.conv_up4(u4)

        u3 = self.upconv3(u4)
        u3 = torch.cat([u3, x3], dim=1)
        u3 = self.conv_up3(u3)

        u2 = self.upconv2(u3)
        u2 = torch.cat([u2, x2], dim=1)
        u2 = self.conv_up2(u2)

        u1 = self.upconv1(u2)
        u1 = torch.cat([u1, x1], dim=1)
        u1 = self.conv_up1(u1)

        return self.final_conv(u1)

class DSC(nn.Module):
    """
    Implementation of the Dice-Sorensen coefficient loss:
    1 - 2 * |predict * target| / (|predict| + |target|)
    """
    def __init__(self):
        super(DSC, self).__init__()
        # Smoothing term to avoid division by zero.
        self.smooth = 1.0

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()

        intersection = (y_pred * y_true).sum(dim=[2, 3, 4])
        # Per-class DSC, averaged over classes and batch.
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum(dim=[2, 3, 4]) + y_true.sum(dim=[2, 3, 4]) + self.smooth
        )

        dsc = dsc.mean()
        return 1. - dsc
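

if __name__ == "__main__":
    # Smoke test (a sketch): verify that the model preserves spatial size and
    # that the DSC loss runs on matching shapes. The 32^3 volume is chosen
    # only to keep this cheap; real inputs are 128^3.
    model = ImprovedUNet3D(in_channels=1, out_channels=6)
    x = torch.randn(1, 1, 32, 32, 32)
    y = model(x)
    print(y.shape)  # expected: torch.Size([1, 6, 32, 32, 32])
    loss = DSC()(torch.softmax(y, dim=1), torch.zeros_like(y))
    print(loss.item())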
64 changes: 64 additions & 0 deletions recognition/UNet3D_47473289/predict.py
# Example usage of the trained UNet3D model.
import torch
import matplotlib.pyplot as plt

from skimage.util import montage
import numpy as np


# Example setup (uncomment to run end to end; requires dataset.py and modules.py):
# from torch.utils.data import DataLoader
# from dataset import MyCustomDataset
# from modules import ImprovedUNet3D
# dataset = MyCustomDataset()
# _, test = torch.utils.data.random_split(dataset, [0.9, 0.1])
# test_loader = DataLoader(test)
# model = ImprovedUNet3D(in_channels=1, out_channels=6).cuda()
# model.load_state_dict(torch.load('./new_folder/model.pth'))
def predict(model, test_loader):
    model.eval()
    with torch.no_grad():
        for test_images, test_masks in test_loader:
            vis = Visualizer()
            vis.visualize(test_images, test_masks)

            # Move test images onto the GPU and predict class labels.
            test_images = test_images.cuda()
            test_outputs = model(test_images)
            predicted_masks = torch.argmax(test_outputs, dim=1)

            # imshow needs a 2D array, so show the middle slice of each volume.
            mid = test_images.shape[-1] // 2

            plt.figure(figsize=(10, 5))
            # The base input image
            plt.subplot(1, 3, 1)
            plt.imshow(test_images[0].cpu().squeeze()[..., mid], cmap='gray')
            plt.title("Input Image")

            # The ground-truth segmentation mask
            plt.subplot(1, 3, 2)
            plt.imshow(test_masks[0].cpu().argmax(dim=0).squeeze()[..., mid], cmap='gray')
            plt.title("Segmentation Mask")

            # The predicted mask
            plt.subplot(1, 3, 3)
            plt.imshow(predicted_masks[0].cpu().squeeze()[..., mid], cmap='gray')
            plt.title("Predicted Mask")

            # Save the comparison figure
            output_path = "output.png"
            plt.savefig(output_path)
            plt.show()
            break


class Visualizer:
    def montage_nd(self, image):
        # Recursively flatten >3D arrays into a single 2D montage.
        if len(image.shape) > 3:
            return montage(np.stack([self.montage_nd(img) for img in image], 0))
        elif len(image.shape) == 3:
            return montage(image)
        else:
            print('Input less than 3d image, returning original')
            return image

    def visualize(self, image, mask):
        fig, axs = plt.subplots(1, 2, figsize=(20, 15 * 2))
        axs[0].imshow(self.montage_nd(image[..., 0]), cmap='bone')
        axs[1].imshow(self.montage_nd(mask[..., 0]), cmap='bone')
        plt.show()
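

if __name__ == "__main__":
    # End-to-end sketch mirroring the commented setup at the top of this file.
    # It assumes dataset.py and modules.py are importable and that a trained
    # checkpoint exists at ./new_folder/model.pth.
    from torch.utils.data import DataLoader
    from dataset import MyCustomDataset
    from modules import ImprovedUNet3D

    dataset = MyCustomDataset()
    _, test = torch.utils.data.random_split(dataset, [0.9, 0.1])
    test_loader = DataLoader(test)
    model = ImprovedUNet3D(in_channels=1, out_channels=6).cuda()
    model.load_state_dict(torch.load('./new_folder/model.pth'))
    predict(model, test_loader)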
