diff --git a/recognition/UNet3D_47473289/README.md b/recognition/UNet3D_47473289/README.md new file mode 100644 index 000000000..b51351991 --- /dev/null +++ b/recognition/UNet3D_47473289/README.md @@ -0,0 +1,31 @@ +# Improved 3D UNet for Prostate MRI Segmentation + +## Overview + +This is an implementation of an improved 3D UNet segmenting the (downsampled) Prostate 3D data set with all labels having a minimum Dice similarity coefficient (DSC) of 0.7 on the test set. The data is augmented using the appropriate transforms in PyTorch and is loaded from the NIfTI file format (".nii.gz"). + + +## Improved UNet3D Model + +The Improved 3D UNet model is a CNN for 3D image segmentation. +The model is based on the Keras BraTS UNet model, using an encoder and a decoder with a bottleneck in between. The encoder applies convolutions to the input images and reduces their dimensions to identify patterns in the images. The decoder applies up-convolutions to the downsampled images, concatenating at each step with the output of the corresponding encoder layer. The bottleneck is the connection between the encoder and the decoder. + +The dataset is split into training data, validation data and testing data. +These datasets are subsets of the total data loaded, split into 70% training, 20% validation and 10% testing. + +Training the model has the goal of minimising the DSC loss on the training data, which in turn maximises +the DSC in evaluation mode on the validation set. Appropriate transformations are applied to the +data to maximise the speed of training.
# Data loader for loading and preprocessing data.
import glob

import numpy as np
import torch
import torch.nn as nn
# DataLoader / TensorDataset are not used in this module but are kept at
# module level because train.py obtains DataLoader via `from dataset import *`.
from torch.utils.data import DataLoader, Dataset, TensorDataset

# Default locations of the label and image volumes.
LABEL_DIR = "/home/groups/comp3710/HipMRI_Study_open/semantic_labels_only"
IMAGE_DIR = "/home/groups/comp3710/HipMRI_Study_open/semantic_MRs"


def to_channels(arr, dtype=np.uint8, num_classes=None):
    """One-hot encode an integer label array along a new trailing axis.

    Channel ``c`` of the result is 1 where ``arr == c`` and 0 elsewhere.

    The previous implementation sized the channel axis by the number of
    *unique* labels but indexed it by label *value*, so any label greater
    than or equal to the unique count (e.g. ``{0, 2, 5}``) was silently
    dropped. The channel axis is now sized by the largest label, which is
    identical for contiguous labels ``0..k-1``.

    Args:
        arr: Array of non-negative labels (values are truncated to int).
        dtype: dtype of the returned array.
        num_classes: Size of the channel axis. Defaults to ``max(arr) + 1``.

    Returns:
        Array of shape ``arr.shape + (num_classes,)``.

    Raises:
        ValueError: If a label is negative or does not fit in ``num_classes``.
    """
    labels = np.asarray(arr).astype(np.int64)
    if labels.size and labels.min() < 0:
        raise ValueError("label values must be non-negative")
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0
    elif labels.size and labels.max() >= num_classes:
        raise ValueError(
            f"label {int(labels.max())} does not fit in {num_classes} classes"
        )
    # Broadcasting against the class indices builds every channel in one pass.
    return (labels[..., np.newaxis] == np.arange(num_classes)).astype(dtype)


def load_data_3D(imageNames, normImage=False, categorical=False, dtype=np.float32,
                 getAffines=False, orient=False, early_stop=False):
    """Load NIfTI volumes from a list of file names into one pre-allocated array.

    The output array is sized from the first volume; later volumes are
    clipped to that size and zero-padded if smaller.

    Args:
        imageNames: Sequence of paths readable by ``nibabel.load``.
        normImage: If True, standardise each volume to zero mean and unit
            standard deviation (a constant volume is only mean-centred).
        categorical: If True, one-hot encode each volume with ``to_channels``
            using the class count of the first volume.
        dtype: dtype of the returned array.
        getAffines: If True, also return the list of affine matrices.
        orient: Accepted for interface compatibility; currently unused.
        early_stop: If True, stop loading once index 21 has been read,
            leaving the remaining entries zero (for quick test runs).

    Returns:
        ``images`` of shape ``(N, rows, cols, depth)`` (plus a trailing
        channel axis when ``categorical``), or ``(images, affines)`` when
        ``getAffines`` is True.

    Raises:
        ValueError: If ``imageNames`` is empty.
    """
    # Imported here so the numpy/torch helpers in this module can be used
    # without the file-I/O dependencies installed.
    import nibabel as nib
    from tqdm import tqdm

    num = len(imageNames)
    if num == 0:
        raise ValueError("imageNames must contain at least one file path")

    first_case = nib.load(imageNames[0]).get_fdata(caching='unchanged')
    if first_case.ndim == 4:
        first_case = first_case[:, :, :, 0]  # sometimes extra dims, remove
    rows, cols, depth = first_case.shape

    if categorical:
        channels = int(first_case.max()) + 1
        images = np.zeros((num, rows, cols, depth, channels), dtype=dtype)
    else:
        channels = None
        images = np.zeros((num, rows, cols, depth), dtype=dtype)

    affines = []
    # A progress bar for a single file is only noise (the dataset class loads
    # one file per call), so it is disabled in that case.
    for i, inName in enumerate(tqdm(imageNames, disable=num < 2)):
        niftiImage = nib.load(inName)
        inImage = niftiImage.get_fdata(caching='unchanged')  # read disk only
        if inImage.ndim == 4:
            inImage = inImage[:, :, :, 0]  # sometimes extra dims in HipMRI_study data

        # Clip every spatial axis (previously only depth) so a volume larger
        # than the first case cannot overflow the pre-allocated array.
        inImage = inImage[:rows, :cols, :depth].astype(dtype)

        if normImage:
            std = inImage.std()
            inImage = inImage - inImage.mean()
            if std > 0:  # avoid dividing a constant volume by zero
                inImage = inImage / std

        r, c, d = inImage.shape[:3]
        if categorical:
            images[i, :r, :c, :d, :] = to_channels(
                inImage, dtype=dtype, num_classes=channels
            )  # with pad
        else:
            images[i, :r, :c, :d] = inImage  # with pad

        affines.append(niftiImage.affine)
        if i > 20 and early_stop:
            print("STOPPED EARly")
            break

    if getAffines:
        return images, affines
    return images


class MyCustomDataset(Dataset):
    """Paired MRI volume / segmentation label dataset read from ``.nii.gz`` files.

    Each item is ``(image, label)`` where ``image`` has shape ``(1, *size)``
    and ``label`` is a one-hot float tensor of shape ``(6, *size)``.

    Args:
        image_dir: Directory searched recursively for image volumes.
        label_dir: Directory searched recursively for label volumes.
        size: Spatial size every volume is resampled to.

    Raises:
        ValueError: If the number of images and labels differ.
    """

    def __init__(self, image_dir=IMAGE_DIR, label_dir=LABEL_DIR, size=(128, 128, 128)):
        # glob returns paths in arbitrary order; sort both lists so that
        # index i of the images lines up with index i of the labels.
        self.label_paths = sorted(glob.glob(f'{label_dir}/**/*.nii.gz', recursive=True))
        self.image_paths = sorted(glob.glob(f'{image_dir}/**/*.nii.gz', recursive=True))
        if len(self.image_paths) != len(self.label_paths):
            raise ValueError(
                f"found {len(self.image_paths)} images but "
                f"{len(self.label_paths)} labels"
            )

        self.up = torch.nn.Upsample(size=tuple(size))

        # Raw label value -> class index; anything not listed maps to 0.
        self.classes = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}
        self.num_classes = 6

        # Lookup tables so the mapping is applied with tensor indexing
        # instead of a Python loop over every voxel.
        table_size = max(self.classes) + 1
        self._lut = torch.zeros(table_size, dtype=torch.long)
        self._known = torch.zeros(table_size, dtype=torch.bool)
        for raw_value, class_index in self.classes.items():
            self._lut[raw_value] = class_index
            self._known[raw_value] = True

    def __len__(self):
        return len(self.image_paths)

    def _remap_labels(self, raw):
        """Map raw label values to class indices; unknown values become 0."""
        as_int = raw.long()
        idx = as_int.clamp(0, self._lut.numel() - 1)
        valid = (as_int.to(raw.dtype) == raw) & (as_int == idx) & self._known[idx]
        return torch.where(valid, self._lut[idx], torch.zeros_like(as_int))

    def _to_tensors(self, image, label):
        """Convert loaded ``(1, R, C, D)`` arrays into resampled model tensors."""
        image = torch.as_tensor(image).float().unsqueeze(1)      # (1, 1, R, C, D)
        label = self._remap_labels(torch.as_tensor(label))       # (1, R, C, D)
        label = nn.functional.one_hot(label.squeeze(0), num_classes=self.num_classes)
        # Move the class axis to the front while keeping (R, C, D) order.
        # The previous permute(3, 1, 0, 2) swapped rows and columns, which
        # transposed every mask relative to its image.
        label = label.permute(3, 0, 1, 2).float().unsqueeze(0)   # (1, K, R, C, D)
        return self.up(image).squeeze(0), self.up(label).squeeze(0)

    def __getitem__(self, i):
        image = load_data_3D([self.image_paths[i]], early_stop=True)
        label = load_data_3D([self.label_paths[i]], early_stop=True)
        return self._to_tensors(image, label)


if __name__ == "__main__":
    # Manual smoke test; guarded so importing this module (train.py does
    # `from dataset import *`) no longer loads and prints a sample.
    dataset = MyCustomDataset()
    print(dataset[10])
    print(dataset[10][1].shape)
# Source code of the components of the UNet3D Model.

import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Two consecutive (3x3x3 Conv3d -> BatchNorm3d -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)


class UpConv(nn.Module):
    """Transposed convolution (kernel 2, stride 2) that doubles each spatial dim."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        return self.up(x)


class ImprovedUNet3D(nn.Module):
    """Four-level 3D UNet with skip connections.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Channels of the output (one per class); the output is
            the raw result of a 1x1x1 convolution, with no activation.
        base_channels: Width of the first encoder level; each deeper level
            doubles it. The default of 64 reproduces the original
            64/128/256/512/1024 layout, so existing checkpoints still load.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Contracting path
        self.conv1 = ConvBlock(in_channels, c1)
        self.pool1 = nn.MaxPool3d(2)

        self.conv2 = ConvBlock(c1, c2)
        self.pool2 = nn.MaxPool3d(2)

        self.conv3 = ConvBlock(c2, c3)
        self.pool3 = nn.MaxPool3d(2)

        self.conv4 = ConvBlock(c3, c4)
        self.pool4 = nn.MaxPool3d(2)

        self.bottleneck = ConvBlock(c4, c5)

        # Expansive path: each ConvBlock receives the upsampled features
        # concatenated with the matching encoder output, hence 2x channels in.
        self.upconv4 = UpConv(c5, c4)
        self.conv_up4 = ConvBlock(c5, c4)

        self.upconv3 = UpConv(c4, c3)
        self.conv_up3 = ConvBlock(c4, c3)

        self.upconv2 = UpConv(c3, c2)
        self.conv_up2 = ConvBlock(c3, c2)

        self.upconv1 = UpConv(c2, c1)
        self.conv_up1 = ConvBlock(c2, c1)

        self.final_conv = nn.Conv3d(c1, out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-class scores with the same spatial size as ``x``.

        Raises:
            ValueError: If ``x`` is not 5-D or a spatial dimension is not a
                multiple of 16 (four 2x poolings must be exactly undone by
                four 2x up-convolutions for the skip concatenations to fit).
        """
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D input (N, C, D, H, W), got {x.dim()}-D")
        if any(size % 16 for size in x.shape[2:]):
            raise ValueError(
                f"spatial dimensions must be multiples of 16, got {tuple(x.shape[2:])}"
            )

        x1 = self.conv1(x)
        x2 = self.conv2(self.pool1(x1))
        x3 = self.conv3(self.pool2(x2))
        x4 = self.conv4(self.pool3(x3))

        bn = self.bottleneck(self.pool4(x4))

        u4 = self.conv_up4(torch.cat([self.upconv4(bn), x4], dim=1))
        u3 = self.conv_up3(torch.cat([self.upconv3(u4), x3], dim=1))
        u2 = self.conv_up2(torch.cat([self.upconv2(u3), x2], dim=1))
        u1 = self.conv_up1(torch.cat([self.upconv1(u2), x1], dim=1))

        return self.final_conv(u1)


class DSC(nn.Module):
    """Dice-Sorensen coefficient loss.

    Computes ``1 - mean(dice)`` where, per sample and channel,
    ``dice = (2 * sum(pred * true) + smooth) / (sum(pred) + sum(true) + smooth)``.

    Args:
        smooth: Added to numerator and denominator to avoid division by zero.
        from_logits: If True, ``y_pred`` is converted to probabilities first
            (softmax over channels, or sigmoid for a single channel). The
            default of False keeps the previous behaviour; the formula is
            only a Dice score when ``y_pred`` holds values in [0, 1].
    """

    def __init__(self, smooth=1.0, from_logits=False):
        super().__init__()
        # Smooth to avoid 0 division
        self.smooth = smooth
        self.from_logits = from_logits

    def forward(self, y_pred, y_true):
        """Return the scalar Dice loss.

        Raises:
            ValueError: If the two tensors differ in shape.
        """
        # Explicit check rather than assert, which is stripped under -O.
        if y_pred.shape != y_true.shape:
            raise ValueError(
                f"shape mismatch: y_pred {tuple(y_pred.shape)} "
                f"vs y_true {tuple(y_true.shape)}"
            )

        if self.from_logits:
            if y_pred.shape[1] > 1:
                y_pred = torch.softmax(y_pred, dim=1)
            else:
                y_pred = torch.sigmoid(y_pred)

        # Reduce over every spatial axis, keeping batch and channel.
        spatial = tuple(range(2, y_pred.dim()))
        intersection = (y_pred * y_true).sum(dim=spatial)
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum(dim=spatial) + y_true.sum(dim=spatial) + self.smooth
        )

        return 1. - dsc.mean()
# Example usage of the trained UNet3D model.
#
#   dataset = MyCustomDataset()
#   _, test = torch.utils.data.random_split(dataset, [0.9, 0.1])
#   test_loader = DataLoader(test)
#   model = ImprovedUNet3D(in_channels=1, out_channels=6).cuda()
#   model.load_state_dict(torch.load('./new_folder/model.pth'))
#   predict(model, test_loader)
import torch
import matplotlib.pyplot as plt

from skimage.util import montage
import numpy as np


def _to_numpy(data):
    """Return ``data`` as a NumPy array, moving tensors off the GPU first."""
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    return np.asarray(data)


def _middle_slice(volume):
    """Reduce a volume to the 2-D slice at the middle of its trailing axes.

    ``plt.imshow`` only accepts 2-D data (or RGB/RGBA), so a 3-D volume has
    to be cut down to a single slice before it can be displayed.
    """
    array = np.squeeze(_to_numpy(volume))
    while array.ndim > 2:
        array = array[..., array.shape[-1] // 2]
    return array


class Visulizer:
    """Montage-based viewer for batches of volumes."""

    def montage_nd(self, image):
        """Recursively tile an N-D array into a single 2-D montage."""
        if len(image.shape) > 3:
            return montage(np.stack([self.montage_nd(img) for img in image], 0))
        elif len(image.shape) == 3:
            return montage(image)
        else:
            print('Input less than 3d image, returning original')
            return image

    def visualize(self, image, mask):
        """Show the first slice (last axis index 0) of ``image`` and ``mask``."""
        fig, axs = plt.subplots(1, 2, figsize=(20, 15 * 2))
        axs[0].imshow(self.montage_nd(_to_numpy(image)[..., 0]), cmap='bone')
        axs[1].imshow(self.montage_nd(_to_numpy(mask)[..., 0]), cmap='bone')
        plt.show()


# Correctly spelled alias; the original name is kept for existing callers.
Visualizer = Visulizer


def predict(model, test_loader, output_path="output.png"):
    """Run the model on the first test batch and plot input, truth and prediction.

    The figure shows the middle slice of each volume, is written to
    ``output_path`` and then displayed.

    Args:
        model: Trained segmentation network returning per-class scores
            with classes on dim 1.
        test_loader: Iterable yielding ``(images, one_hot_masks)`` batches.
        output_path: Where the comparison figure is saved.
    """
    # Run on whatever device the model already lives on instead of assuming
    # CUDA, so a CPU model is not fed GPU tensors.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        for test_images, test_masks in test_loader:
            Visulizer().visualize(test_images, test_masks)

            test_outputs = model(test_images.to(device))
            predicted_masks = torch.argmax(test_outputs, dim=1)

            panels = (
                ("Input Image", test_images[0]),
                ("Segmentation Mask", test_masks[0].argmax(dim=0)),
                ("Predicted Mask", predicted_masks[0]),
            )

            plt.figure(figsize=(10, 5))
            for position, (title, volume) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                # Passing the whole 3-D volume to imshow raises a TypeError.
                plt.imshow(_middle_slice(volume), cmap='gray')
                plt.title(title)

            # Save image
            plt.savefig(output_path)
            plt.show()
            break
# Source code for training, validating, testing and saving the UNet3D Model.
from modules import DSC, ImprovedUNet3D
from dataset import MyCustomDataset
from predict import predict
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
import time
import glob
import os

# No. of Epochs (Number of times all training data is run)
N_EPOCHS = 5
# Peak learning rate of the cyclic schedule.
LEARN_RATE = 0.1
# Where the trained weights are written.
MODEL_PATH = "model.pth"


def build_scheduler(optimizer, learn_rate=LEARN_RATE):
    """Cyclic learning rate for 30 steps, followed by a linear decay.

    The deprecated ``verbose`` argument is no longer passed to the
    schedulers; it only triggered deprecation warnings.
    """
    # Alternates minimum and maximum learning rate
    l_sched_1 = optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=0.005, max_lr=learn_rate,
        step_size_up=15, step_size_down=15, mode='triangular',
    )
    # Linearly reduces learning rate
    l_sched_2 = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.005 / learn_rate, end_factor=0.005 / 5,
    )
    # Sequentially applies sched 1 and 2
    return optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[l_sched_1, l_sched_2], milestones=[30]
    )


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over ``loader`` and return the mean training loss."""
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # The network outputs raw scores; Dice overlap is only meaningful on
        # probabilities, so normalise across the class dimension first.
        probs = torch.softmax(model(inputs), dim=1)
        loss = criterion(probs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / max(len(loader), 1)


def validate(model, loader, criterion, device):
    """Return the mean Dice score over ``loader`` (0.0 if it is empty)."""
    model.eval()
    scores = []
    with torch.no_grad():
        for val_inputs, val_labels in loader:
            val_inputs = val_inputs.to(device)
            val_labels = val_labels.to(device)
            probs = torch.softmax(model(val_inputs), dim=1)
            # Dice score is one minus the Dice loss; stored as a plain float
            # so no GPU tensors are kept alive across batches.
            scores.append(1.0 - criterion(probs, val_labels).item())
    # Average validation score for all batches
    return sum(scores) / len(scores) if scores else 0.0


def main():
    """Train, validate, save and visually test the UNet3D model."""
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cust_dataset = MyCustomDataset()
    train, val, test = random_split(cust_dataset, [0.7, 0.2, 0.1])
    train = DataLoader(train)
    val = DataLoader(val)
    test = DataLoader(test)

    model = ImprovedUNet3D(in_channels=1, out_channels=6).to(device)
    # NOTE: CyclicLR overrides this initial lr with its own base_lr.
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = DSC()
    scheduler = build_scheduler(optimizer)

    train_losses = []
    print(">> Training <<")
    print(len(train))
    start = time.time()
    for epoch in range(N_EPOCHS):
        epstart = time.time()
        train_loss = train_one_epoch(model, train, criterion, optimizer, device)
        train_losses.append(train_loss)
        scheduler.step()

        mean_score = validate(model, val, criterion, device)
        epend = time.time()
        # Report the epoch average rather than the last batch's loss.
        print(f"Epoch {epoch + 1}/{N_EPOCHS}, Training Loss: {train_loss}, Validation DSC: {mean_score}")
        print(f"Epoch {epoch + 1}/{N_EPOCHS} took {epend - epstart} seconds")

    # Stop the clock before predict(), which blocks on plot windows.
    total_time = time.time() - start
    print(f"Training and Validation took {total_time} seconds or {total_time/60} minutes.")

    torch.save(model.state_dict(), MODEL_PATH)

    predict(model, test)


if __name__ == "__main__":
    main()