Mastering U-Net: A Step-by-Step Guide to Segmentation from Scratch with PyTorch (2024)

FernandoPC25 · 15 min read · Apr 25, 2024

In the field of computer vision, capturing the world as humans perceive and understand it has consistently been a cornerstone of groundbreaking advancements. One of the pivotal techniques that enables this understanding is image segmentation. Segmentation is the process of dividing an image into multiple segments or regions to simplify its representation and make it easier to analyze.

So, why is segmentation so crucial? At its core, segmentation provides context and meaning to individual pixels, transforming raw images into structured data that machines can interpret. This capability is indispensable in various applications such as medical imaging, autonomous vehicles, and object detection. By segmenting images, we can identify and extract specific objects, delineate boundaries, and even classify regions based on their content.

Creating models that excel at segmentation is not just about accurate delineation; it’s about empowering machines to understand and interpret visual data with precision and efficiency. A well-designed segmentation model can significantly enhance the performance of downstream tasks, leading to more robust and intelligent systems.

Among the myriad of segmentation techniques and models that have emerged over the years, U-Net stands out as a state-of-the-art solution that has revolutionized the field. Developed by researchers at the Computer Science Department of the University of Freiburg in 2015, U-Net has gained widespread acclaim.

For this tutorial, I am going to do the training in Kaggle Notebooks, since it allows you to use a powerful GPU at no cost. In my case, I am using a P100 GPU.

2.1) What is U-Net?

U-Net is a convolutional neural network (CNN) architecture that was specifically designed for biomedical image segmentation tasks. Developed in 2015, U-Net has become one of the go-to architectures for various segmentation tasks due to its effectiveness and efficiency. You can find the original paper here.

The U-Net architecture is characterized by its U-shaped structure, which gives it its name. It consists of an encoding path and a decoding path.

  • Encoding Path: This part of the network captures the context of the input image by using a series of convolutional and max-pooling layers to downsample the spatial dimensions. It “contracts” the original image.
  • Decoding Path: The decoding path uses upsampling and convolutional layers to produce a segmentation map with the same spatial dimensions as the input image. It “expands” the contracted representation.

The architecture is shown in Figure 1.

[Figure 1: The U-Net architecture]

U-Net’s strength in segmentation comes from its use of skip connections (grey arrows in Figure 1), which connect the encoding and decoding paths by merging features. This helps retain spatial details lost during downsampling, preserving the image’s local and global context. By maintaining this spatial information, U-Net achieves more accurate segmentation masks. The skip connections help the network grasp the relationships between image parts, leading to improved segmentation results.
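To make this concrete, here is a minimal sketch (with shapes chosen purely for illustration, not taken from the article) of what a skip connection does: the upsampled decoder tensor is concatenated channel-wise with the encoder tensor saved at the same spatial resolution.

import torch

decoder_features = torch.rand(1, 512, 64, 64)  # upsampled tensor from the decoding path
encoder_features = torch.rand(1, 512, 64, 64)  # tensor saved from the encoding path
merged = torch.cat([decoder_features, encoder_features], dim=1)  # merge along the channel axis
print(merged.shape)  # torch.Size([1, 1024, 64, 64])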

Now let’s code the different components of the U-Net so we can use it!

2.2) Coding the U-Net

The required libraries to run this notebook are:

import copy
import os
import random
import shutil
import zipfile
from math import atan2, cos, sin, sqrt, pi, log

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from numpy import linalg as LA
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm

First, let’s define the double convolution that is repeated at each step (blue arrow). As shown in the figure, it consists of two 3x3 convolutions, each followed by a ReLU activation:

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv_op(x)
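As a quick sanity check (a small sketch I am adding here, not part of the original notebook), we can confirm that with padding=1 the spatial size is preserved and only the number of channels changes:

block = DoubleConv(in_channels=3, out_channels=64)
dummy = torch.rand(1, 3, 128, 128)
print(block(dummy).shape)  # torch.Size([1, 64, 128, 128])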

Now let’s define the downsampling part. This corresponds to the left part of the figure (encoding path), where we apply the double convolution followed by max pooling (red arrow):

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        down = self.conv(x)
        p = self.pool(down)

        return down, p

The U-Net architecture includes skip connections, which allow the fusion of low-level and high-level features and aid localization. In this downsampling part of the architecture, before applying the max pooling, we save the convolved tensor; it is later concatenated with an upsampled tensor of matching dimensions. In the code, this is why the DownSample class returns two variables, down and p.
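A quick shape check (again a small illustrative sketch, not from the original notebook) shows that down keeps the input resolution while p has halved spatial dimensions:

down_block = DownSample(in_channels=3, out_channels=64)
down, p = down_block(torch.rand(1, 3, 128, 128))
print(down.shape)  # torch.Size([1, 64, 128, 128]) -> saved for the skip connection
print(p.shape)     # torch.Size([1, 64, 64, 64])   -> passed to the next encoder stage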

Lastly, we define the upsampling part. This corresponds to the right part of the figure (decoding path). It is done with a transposed convolution (green arrow) followed by the double convolution. As we can see, there are four copy-and-crop connections (gray arrows), each taken from the encoding path just before a max pooling:

class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)

Here we might ask why UpSample receives two tensors while DownSample only receives one. DownSample returns two variables, down and p; down is saved so that it can later be concatenated with the output of the transposed convolution (nn.ConvTranspose2d) inside UpSample. DownSample only receives one tensor because skip connections are not applied in the encoding path, only in the decoding one.
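A small sketch (with illustrative shapes) makes the channel bookkeeping explicit: the transposed convolution halves the channels and doubles the spatial size, the concatenation with the saved encoder tensor restores the channel count, and the double convolution then reduces it again:

up_block = UpSample(in_channels=1024, out_channels=512)
bottleneck = torch.rand(1, 1024, 32, 32)  # e.g. the bottleneck output
skip = torch.rand(1, 512, 64, 64)         # e.g. the tensor saved by the matching DownSample
print(up_block(bottleneck, skip).shape)   # torch.Size([1, 512, 64, 64])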

Now let’s combine all these classes into the UNet architecture:

class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        down_1, p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)

        b = self.bottle_neck(p4)

        up_1 = self.up_convolution_1(b, down_4)
        up_2 = self.up_convolution_2(up_1, down_3)
        up_3 = self.up_convolution_3(up_2, down_2)
        up_4 = self.up_convolution_4(up_3, down_1)

        out = self.out(up_4)
        return out

A graphical explanation of the class UNet is shown in this picture:

[Figure: Graphical explanation of the UNet class]

We are going to test our model with a dummy input of size [1, 3, 512, 512], expecting an output of size [1, 10, 512, 512].

input_image = torch.rand((1,3,512,512))
model = UNet(3,10)
output = model(input_image)
print(output.size())
# You should get torch.Size([1, 10, 512, 512]) as a result
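As an optional extra (not in the original article), we can also count the model’s trainable parameters:

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")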

Now that the architecture is explained (and hopefully understood!) let’s train it!

Now let’s put the model to work. In this case, since we are segmenting a single figure from the background, num_classes=1. We will train the model using the Carvana Dataset.

The Carvana Dataset has the following elements:

  • /train/ — this folder contains the training set images
  • /test/ — this folder contains the test set images. You must predict the mask (in run-length encoded format) for each of the images in this folder
  • /train_masks/ — this folder contains the training set masks in .gif format
  • train_masks.csv — for convenience, this file gives a run-length encoded version of the training set masks.
  • sample_submission.csv — shows the correct submission format
  • metadata.csv — contains basic information about all the cars in the dataset. Note that some values are missing.

We are going to use only /train/ and /train_masks/. I will add a variable called “limit” in case you want to load only a few images to run a fast experiment.

class CarvanaDataset(Dataset):
    def __init__(self, root_path, limit=None):
        self.root_path = root_path
        self.limit = limit
        self.images = sorted([root_path + "/train/" + i for i in os.listdir(root_path + "/train/")])[:self.limit]
        self.masks = sorted([root_path + "/train_masks/" + i for i in os.listdir(root_path + "/train_masks/")])[:self.limit]

        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor()])

        if self.limit is None:
            self.limit = len(self.images)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert("RGB")
        mask = Image.open(self.masks[index]).convert("L")

        # The image path is returned as well, so it can be displayed during inference later on
        return self.transform(img), self.transform(mask), self.images[index]

    def __len__(self):
        return min(len(self.images), self.limit)

I am doing the training in a Kaggle Notebook. To load the data there, go to Input > Add Input and search for “Carvana Image Masking Challenge”. Once it is added, you can list the dataset contents by running:


print(os.listdir("../input/carvana-image-masking-challenge/"))

DATASET_DIR = '../input/carvana-image-masking-challenge/'
WORKING_DIR = '/kaggle/working/'

To extract the zipped files, you can run this snippet:

if len(os.listdir(WORKING_DIR)) <= 1:

    with zipfile.ZipFile(DATASET_DIR + 'train.zip', 'r') as zip_file:
        zip_file.extractall(WORKING_DIR)

    with zipfile.ZipFile(DATASET_DIR + 'train_masks.zip', 'r') as zip_file:
        zip_file.extractall(WORKING_DIR)

print(
    len(os.listdir(WORKING_DIR + 'train')),
    len(os.listdir(WORKING_DIR + 'train_masks'))
)

And now you are ready to use the CarvanaDataset class:

train_dataset = CarvanaDataset(WORKING_DIR)

generator = torch.Generator().manual_seed(25)

Now we are going to split the data into training, validation and testing. First of all, we will take 80% of the data for training and 20% for testing.

train_dataset, test_dataset = random_split(train_dataset, [0.8, 0.2], generator=generator)

Next, we split that 20% into validation and testing subsets: half of it for testing and the other half for validation.

test_dataset, val_dataset = random_split(test_dataset, [0.5, 0.5], generator=generator)

In summary, our data will be divided as follows: 80% for training, 10% for testing, and 10% for validation.
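A quick way to confirm the proportions (a small check of my own, assuming the splits above):

print(len(train_dataset), len(val_dataset), len(test_dataset))
# The three lengths should correspond roughly to an 80/10/10 split of the full dataset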

Running this experiment with CUDA is crucial for speed and memory efficiency. That’s why we’re using a Kaggle Notebook, where we can use a P100 GPU for free.

device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    num_workers = torch.cuda.device_count() * 4
else:
    num_workers = 2  # reasonable fallback when running on CPU

We’re setting up the DataLoaders for training, validation, and testing phases. Given the size of our dataset, we’ve chosen a batch size of 8 to prevent GPU memory exhaustion. Additionally, we’re keeping the pin_memory parameter set to False to avoid potential memory issues. While setting pin_memory to True might offer faster data transfer to the GPU, it can also lead to memory allocation problems.

Next, we set up our model using the AdamW optimizer and the BCEWithLogitsLoss loss criterion.

LEARNING_RATE = 3e-4
BATCH_SIZE = 8

train_dataloader = DataLoader(dataset=train_dataset,
                              num_workers=num_workers, pin_memory=False,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset,
                            num_workers=num_workers, pin_memory=False,
                            batch_size=BATCH_SIZE,
                            shuffle=True)

test_dataloader = DataLoader(dataset=test_dataset,
                             num_workers=num_workers, pin_memory=False,
                             batch_size=BATCH_SIZE,
                             shuffle=True)

model = UNet(in_channels=3, num_classes=1).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

4.1) Evaluating Segmentation Performance with the DICE Metric

When it comes to assessing the performance of a segmentation model, it’s crucial to have reliable metrics that can quantify the accuracy and quality of the segmentation results. One widely used metric for this purpose is the DICE coefficient, also known as the Dice similarity coefficient or Dice index.

The DICE metric provides a measure of the similarity between two sets, in this case, the predicted segmentation and the ground truth segmentation. It calculates the overlap between the two sets, taking into account both the false positives and false negatives.

Mathematically, the DICE score is defined as:

DICE score = 2 * |A ∩ B| / (|A| + |B|)

And can be understood as:
Dice score = 2 * (number of common elements) / (number of elements in set A + number of elements in set B)

The DICE coefficient ranges from 0 to 1, where a value closer to 1 indicates a higher degree of overlap and thus better segmentation performance. A DICE score of 1 would mean a perfect overlap between the predicted and ground truth segmentations, while a score of 0 would indicate no overlap at all.

[Figure: Illustration of the DICE coefficient]

In our segmentation case, we are comparing two matrices. Consider matrix A as the predicted mask, which has a single channel and contains elements that are either 0 or 1. When matrix A is multiplied element-wise by matrix B, the reference mask (which also contains only 0s and 1s), the resulting matrix has a value of 1 at position (i, j) only if both A and B have a value of 1 at that same position.

Note that this happens because we are using the “*” operator between the matrices. In PyTorch (as in NumPy), this operator multiplies element by element, which is exactly what we want. It is not a standard matrix multiplication!

Let’s illustrate this with a practical example:

[Figure: Element-wise multiplication of two similar masks versus two dissimilar masks]

The first scenario involves matrices representing two images that are highly similar: the positions containing 1s and 0s align closely. In the second scenario, this alignment is not present. As a result, the intersection of the images (the sum of the element-wise product of the matrices) yields a higher value in the first case than in the second.
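To see the difference between the operators concretely, here is a tiny sketch with made-up values, contrasting the element-wise product with a standard matrix multiplication:

A = torch.tensor([[1., 0.],
                  [1., 1.]])
B = torch.tensor([[1., 1.],
                  [0., 1.]])

print(A * B)  # element-wise product: tensor([[1., 0.], [0., 1.]])
print(A @ B)  # standard matrix multiplication, shown only for contrast: tensor([[1., 1.], [1., 2.]])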

With that in mind, we define our metric as follows:

def dice_coefficient(prediction, target, epsilon=1e-07):
    prediction_copy = prediction.clone()

    prediction_copy[prediction_copy < 0] = 0
    prediction_copy[prediction_copy > 0] = 1

    intersection = abs(torch.sum(prediction_copy * target))
    union = abs(torch.sum(prediction_copy) + torch.sum(target))
    dice = (2. * intersection + epsilon) / (union + epsilon)

    return dice
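Here is a tiny worked example of the function above (values chosen purely for illustration). The raw logits below threshold to the binary mask [[1, 0], [1, 0]], which overlaps the target in one pixel:

pred_logits = torch.tensor([[ 2.0, -1.5],
                            [ 0.5, -3.0]])  # raw model outputs (logits)
target = torch.tensor([[1., 0.],
                       [0., 0.]])

print(dice_coefficient(pred_logits, target))
# intersection = 1, |A| + |B| = 2 + 1 = 3, so DICE = 2 * 1 / 3 ≈ 0.6667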

Before diving into the training process, it’s crucial to address potential memory issues to ensure smooth execution. When using PyTorch with CUDA for GPU-accelerated training, one common practice is to call torch.cuda.empty_cache(). This function releases all the unused cached memory from the CUDA context, helping to free up GPU memory that might otherwise lead to out-of-memory errors during training.

torch.cuda.empty_cache()

4.2) Training the model

Now, let’s kick off our experiment! We’ll conduct a traditional neural network training using PyTorch.

EPOCHS = 10

train_losses = []
train_dcs = []
val_losses = []
val_dcs = []

for epoch in tqdm(range(EPOCHS)):
    model.train()
    train_running_loss = 0
    train_running_dc = 0

    for idx, img_mask in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        img = img_mask[0].float().to(device)
        mask = img_mask[1].float().to(device)

        y_pred = model(img)
        optimizer.zero_grad()

        dc = dice_coefficient(y_pred, mask)
        loss = criterion(y_pred, mask)

        train_running_loss += loss.item()
        train_running_dc += dc.item()

        loss.backward()
        optimizer.step()

    train_loss = train_running_loss / (idx + 1)
    train_dc = train_running_dc / (idx + 1)

    train_losses.append(train_loss)
    train_dcs.append(train_dc)

    model.eval()
    val_running_loss = 0
    val_running_dc = 0

    with torch.no_grad():
        for idx, img_mask in enumerate(tqdm(val_dataloader, position=0, leave=True)):
            img = img_mask[0].float().to(device)
            mask = img_mask[1].float().to(device)

            y_pred = model(img)
            loss = criterion(y_pred, mask)
            dc = dice_coefficient(y_pred, mask)

            val_running_loss += loss.item()
            val_running_dc += dc.item()

        val_loss = val_running_loss / (idx + 1)
        val_dc = val_running_dc / (idx + 1)

    val_losses.append(val_loss)
    val_dcs.append(val_dc)

    print("-" * 30)
    print(f"Training Loss EPOCH {epoch + 1}: {train_loss:.4f}")
    print(f"Training DICE EPOCH {epoch + 1}: {train_dc:.4f}")
    print("\n")
    print(f"Validation Loss EPOCH {epoch + 1}: {val_loss:.4f}")
    print(f"Validation DICE EPOCH {epoch + 1}: {val_dc:.4f}")
    print("-" * 30)

# Saving the model
torch.save(model.state_dict(), 'my_checkpoint.pth')

Obtaining:

100%|██████████| 509/509 [06:34<00:00, 1.29it/s]
100%|██████████| 64/64 [00:19<00:00, 3.21it/s]
10%|█ | 1/10 [06:54<1:02:09, 414.40s/it]
------------------------------
Training Loss EPOCH 1: 0.1945
Training DICE EPOCH 1: 0.7723

Validation Loss EPOCH 1: 0.0413
Validation DICE EPOCH 1: 0.9606
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:19<00:00, 3.20it/s]
20%|██ | 2/10 [13:47<55:09, 413.72s/it]
------------------------------
Training Loss EPOCH 2: 0.0304
Training DICE EPOCH 2: 0.9714

Validation Loss EPOCH 2: 0.0201
Validation DICE EPOCH 2: 0.9812
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:19<00:00, 3.21it/s]
30%|███ | 3/10 [20:40<48:14, 413.51s/it]
------------------------------
Training Loss EPOCH 3: 0.0535
Training DICE EPOCH 3: 0.9464

Validation Loss EPOCH 3: 0.0658
Validation DICE EPOCH 3: 0.9384
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:20<00:00, 3.19it/s]
40%|████ | 4/10 [27:34<41:20, 413.42s/it]
------------------------------
Training Loss EPOCH 4: 0.0255
Training DICE EPOCH 4: 0.9763

Validation Loss EPOCH 4: 0.0155
Validation DICE EPOCH 4: 0.9856
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:20<00:00, 3.19it/s]
50%|█████ | 5/10 [34:27<34:26, 413.40s/it]
------------------------------
Training Loss EPOCH 5: 0.0140
Training DICE EPOCH 5: 0.9870

Validation Loss EPOCH 5: 0.0122
Validation DICE EPOCH 5: 0.9886
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:20<00:00, 3.19it/s]
60%|██████ | 6/10 [41:20<27:33, 413.41s/it]
------------------------------
Training Loss EPOCH 6: 0.0113
Training DICE EPOCH 6: 0.9894

Validation Loss EPOCH 6: 0.0110
Validation DICE EPOCH 6: 0.9896
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:20<00:00, 3.19it/s]
70%|███████ | 7/10 [48:14<20:40, 413.40s/it]
------------------------------
Training Loss EPOCH 7: 0.0097
Training DICE EPOCH 7: 0.9909

Validation Loss EPOCH 7: 0.0090
Validation DICE EPOCH 7: 0.9914
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:19<00:00, 3.20it/s]
80%|████████ | 8/10 [55:07<13:46, 413.40s/it]
------------------------------
Training Loss EPOCH 8: 0.0087
Training DICE EPOCH 8: 0.9917

Validation Loss EPOCH 8: 0.0087
Validation DICE EPOCH 8: 0.9916
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:20<00:00, 3.20it/s]
90%|█████████ | 9/10 [1:02:01<06:53, 413.51s/it]
------------------------------
Training Loss EPOCH 9: 0.0080
Training DICE EPOCH 9: 0.9923

Validation Loss EPOCH 9: 0.0075
Validation DICE EPOCH 9: 0.9928
------------------------------
100%|██████████| 509/509 [06:33<00:00, 1.29it/s]
100%|██████████| 64/64 [00:20<00:00, 3.17it/s]
100%|██████████| 10/10 [1:08:55<00:00, 413.52s/it]
------------------------------
Training Loss EPOCH 10: 0.0083
Training DICE EPOCH 10: 0.9921

Validation Loss EPOCH 10: 0.0074
Validation DICE EPOCH 10: 0.9928
------------------------------

5.1) Training and validation

We’ll plot both the loss and the DICE score across epochs.

epochs_list = list(range(1, EPOCHS + 1))

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(epochs_list, train_losses, label='Training Loss')
plt.plot(epochs_list, val_losses, label='Validation Loss')
plt.xticks(ticks=list(range(1, EPOCHS + 1, 1)))
plt.title('Loss over epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.grid()
plt.tight_layout()

plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs_list, train_dcs, label='Training DICE')
plt.plot(epochs_list, val_dcs, label='Validation DICE')
plt.xticks(ticks=list(range(1, EPOCHS + 1, 1)))
plt.title('DICE Coefficient over epochs')
plt.xlabel('Epochs')
plt.ylabel('DICE')
plt.grid()
plt.legend()

plt.tight_layout()
plt.show()

[Figure: Loss and DICE coefficient over epochs for training and validation]

We observe a significant decrease in the loss from epochs 1 to 4, indicating effective learning on both the training and validation sets. The DICE metric also improves during this period. From epochs 4 to 10, the loss and the DICE score remain nearly unchanged, suggesting minimal learning progress. However, upon closer examination of the loss plot, we can discern some subtle learning still taking place. This phase is important, as the model begins to refine the finer details of the mask pixels.

epochs_list = list(range(1, EPOCHS + 1))

plt.figure(figsize=(12, 5))
plt.plot(epochs_list, train_losses, label='Training Loss')
plt.plot(epochs_list, val_losses, label='Validation Loss')
plt.xticks(ticks=list(range(1, EPOCHS + 1, 1)))
plt.ylim(0, 0.05)
plt.title('Loss over epochs (zoomed)')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.grid()
plt.tight_layout()

plt.legend()
plt.show()

[Figure: Loss over epochs (zoomed)]

5.2) Test

Now, we’ll evaluate the trained model on the test images to ensure it hasn’t overfit. While we’ve largely ruled out overfitting given the close alignment of the validation and training loss curves, it’s still important to verify that the model performs well on unseen images. We load the model parameters:

model_pth = '/kaggle/working/my_checkpoint.pth'
trained_model = UNet(in_channels=3, num_classes=1).to(device)
trained_model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))
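As an optional good practice (my own addition, not in the original notebook), you can switch the model to evaluation mode before testing. This particular U-Net has no dropout or batch-norm layers, so it does not change the results here, but it avoids surprises if the architecture changes:

trained_model.eval()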

Now, we’ll test the model in the same manner as we did for the validation set.

test_running_loss = 0
test_running_dc = 0

with torch.no_grad():
    for idx, img_mask in enumerate(tqdm(test_dataloader, position=0, leave=True)):
        img = img_mask[0].float().to(device)
        mask = img_mask[1].float().to(device)

        y_pred = trained_model(img)
        loss = criterion(y_pred, mask)
        dc = dice_coefficient(y_pred, mask)

        test_running_loss += loss.item()
        test_running_dc += dc.item()

    test_loss = test_running_loss / (idx + 1)
    test_dc = test_running_dc / (idx + 1)

We obtain test_loss=0.00772 and test_dc=0.99269, which are very good metrics. Let’s visualize the results! To do so, we define the following function.

def random_images_inference(image_tensors, mask_tensors, image_paths, model_pth, device):
    model = UNet(in_channels=3, num_classes=1).to(device)
    model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))

    transform = transforms.Compose([
        transforms.Resize((512, 512))
    ])

    # Iterate over the images, masks and paths
    for image_tensor, mask_tensor, image_path in zip(image_tensors, mask_tensors, image_paths):
        # Resize the image tensor
        img = transform(image_tensor)

        # Predict the mask with the model
        pred_mask = model(img.unsqueeze(0))
        pred_mask = pred_mask.squeeze(0).permute(1, 2, 0)

        # Load the mask to compare
        mask = transform(mask_tensor).permute(1, 2, 0).to(device)

        print(f"Image: {os.path.basename(image_path)}, DICE coefficient: {round(float(dice_coefficient(pred_mask, mask)), 5)}")

        # Show the images
        img = img.cpu().detach().permute(1, 2, 0)
        pred_mask = pred_mask.cpu().detach()
        pred_mask[pred_mask < 0] = 0
        pred_mask[pred_mask > 0] = 1

        plt.figure(figsize=(15, 16))
        plt.subplot(131), plt.imshow(img), plt.title("original")
        plt.subplot(132), plt.imshow(pred_mask, cmap="gray"), plt.title("predicted")
        plt.subplot(133), plt.imshow(mask, cmap="gray"), plt.title("mask")
        plt.show()

We pick some random images from the test_dataloader:

n = 10

image_tensors = []
mask_tensors = []
image_paths = []

for _ in range(n):
    random_index = random.randint(0, len(test_dataloader.dataset) - 1)
    random_sample = test_dataloader.dataset[random_index]

    image_tensors.append(random_sample[0])
    mask_tensors.append(random_sample[1])
    image_paths.append(random_sample[2])

And now we plot the results. Here we will only show one example, but by running the notebook you can see as many as you want.

model_path = '/kaggle/working/my_checkpoint.pth'

random_images_inference(image_tensors, mask_tensors, image_paths, model_path, device="cpu")

[Figure: Original image, predicted mask, and ground-truth mask for a test sample]

We can see that the predicted mask is very similar to the ground-truth mask!

In conclusion, this article has provided an insightful journey into image segmentation, introducing the concept of U-Net as a powerful architecture tailored for this task. We walked through the practical steps of coding a U-Net model and applying it to the Carvana Dataset for segmentation.

To quantify the accuracy of our segmentation, we adopted the DICE score as our evaluation metric. With DICE scores consistently exceeding 0.99, our experiment strongly validates the effectiveness of the U-Net model in accurately predicting masks. This high degree of similarity between the predicted masks and the ground truth underscores U-Net’s robustness and reliability in image segmentation tasks.

I recommend downloading the notebook from this link and importing it into a Kaggle Notebook! There you will be able to run the experiment without overloading your local machine’s memory.

Thank you for taking the time to read!
If you found the story helpful, feel free to give a clap👏, and please don’t hesitate to reach out if anything was unclear!

Follow me to join me in this learning journey!

https://arxiv.org/abs/1505.04597v1
