U-Net is a widely used deep learning architecture, first introduced in the paper “U-Net: Convolutional Networks for Biomedical Image Segmentation”. It was designed to address the scarcity of annotated data in the medical field: the network makes effective use of a small number of training images while remaining fast and accurate.
U-Net Architecture:
The architecture of U-Net is unique in that it consists of a contracting path and an expansive path. The contracting path contains encoder layers that capture contextual information and reduce the spatial resolution of the input, while the expansive path contains decoder layers that decode the encoded data and use the information from the contracting path via skip connections to generate a segmentation map.
The contracting path in U-Net is responsible for identifying the relevant features in the input image. The encoder layers apply convolutions and pooling that reduce the spatial resolution of the feature maps while increasing their depth, thereby capturing increasingly abstract representations of the input. This contracting path resembles the feature-extraction layers of other convolutional neural networks. The expansive path, on the other hand, decodes the encoded data and localizes the features while restoring spatial resolution. The decoder layers in the expansive path upsample the feature maps and then apply convolutions. The skip connections carry higher-resolution feature maps from the contracting path across to the decoder, which recovers spatial detail lost to downsampling and helps the decoder layers locate the features more accurately.
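To make the resolution reduction in the contracting path concrete, here is a minimal pure-Python sketch of 2×2 max pooling with stride 2 on a single-channel feature map. The function name `max_pool_2x2` and the sample values are illustrative assumptions, not part of the article's implementation, which relies on `tf.keras.layers.MaxPool2D`.

```python
def max_pool_2x2(feature_map):
    """2x2 max pooling with stride 2 on a 2D list with even height and width."""
    pooled = []
    for r in range(0, len(feature_map), 2):
        row = []
        for c in range(0, len(feature_map[0]), 2):
            # Keep only the strongest response in each 2x2 window.
            window = (feature_map[r][c], feature_map[r][c + 1],
                      feature_map[r + 1][c], feature_map[r + 1][c + 1])
            row.append(max(window))
        pooled.append(row)
    return pooled


# Hypothetical 4x4 feature map; pooling halves both height and width.
feature_map = [
    [1, 3, 2, 0],
    [4, 2, 1, 1],
    [0, 1, 5, 6],
    [2, 2, 7, 8],
]
pooled = max_pool_2x2(feature_map)
print(pooled)  # [[4, 2], [2, 8]]
```

Each pooling step discards three of every four spatial positions, which is why the skip connections are needed to bring fine detail back into the decoder.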
Figure 1: U-Net Architecture
Figure 1 illustrates how the U-Net network converts a grayscale input image of size 572×572×1 into a binary segmentation map of size 388×388×2. The output is smaller than the input because the convolutions use no padding, so every 3×3 convolution trims one pixel from each border. With padding, the output can keep the height and width of the input. Along the contracting path, the feature maps are progressively reduced in height and width while the number of channels increases. This increase in channels allows the network to capture high-level features as it progresses down the path. At the bottleneck, two further convolutions produce a 30×30×1024 feature map and then a 28×28×1024 feature map. The expansive path takes the feature map from the bottleneck and restores spatial resolution step by step, ending at the 388×388 output rather than the full input size. This is done using upsampling layers, which increase the spatial resolution of the feature map while reducing the number of channels. The skip connections from the contracting path help the decoder layers locate and refine the features in the image. Finally, each pixel in the output map carries a label that corresponds to a particular object or class in the input image. In this case, the output has two channels, one per class, so each pixel is classified as foreground or background.
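The size arithmetic described above can be checked with a short pure-Python sketch. It assumes the layout described in this article: four encoder levels, two unpadded 3×3 convolutions per level, 2×2 max pooling, and 2×2 up-convolutions. The function name `trace_unet_sizes` and its parameters are hypothetical helpers introduced only for illustration.

```python
def trace_unet_sizes(size, depth=4, convs_per_block=2, kernel=3):
    """Return the spatial size after each stage of an unpadded U-Net."""
    shrink = kernel - 1  # an unpadded k x k convolution trims k - 1 pixels
    sizes = [size]

    # Contracting path: two convolutions, then 2x2 max pooling halves the size.
    for _ in range(depth):
        size -= convs_per_block * shrink
        sizes.append(size)
        size //= 2
        sizes.append(size)

    # Bottleneck: two more unpadded convolutions.
    size -= convs_per_block * shrink
    sizes.append(size)

    # Expansive path: a 2x2 up-convolution doubles the size, then two convolutions.
    for _ in range(depth):
        size *= 2
        sizes.append(size)
        size -= convs_per_block * shrink
        sizes.append(size)

    return sizes


sizes = trace_unet_sizes(572)
print(sizes)
# [572, 568, 284, 280, 140, 136, 68, 64, 32, 28, 56, 52, 104, 100, 200, 196, 392, 388]
```

The trace starts at 572, reaches 28 after the bottleneck, and ends at 388, which matches the input and output sizes quoted for Figure 1.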
Build the Model:
Next, we will implement the U-Net architecture using Python 3 and the TensorFlow library. The implementation can be divided into three parts. First, we will define the encoder block used in the contracting path. This block consists of two 3×3 convolution layers, each followed by a ReLU activation, and then a 2×2 max pooling layer. The second part is the decoder block, which takes the feature map from the lower layer, upsamples it, matches the encoder features of the same level to its shape (the paper crops them, while the code in this article resizes them), concatenates the two, and then applies two 3×3 convolution layers, each followed by a ReLU activation. The third part is defining the U-Net model using these blocks.
Encoder
Here’s the code for the encoder block:
Python3
import tensorflow as tf


def encoder_block(inputs, num_filters):
    # First 3x3 convolution (no padding) followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters, 3, padding='valid')(inputs)
    x = tf.keras.layers.Activation('relu')(x)

    # Second 3x3 convolution (no padding) followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters, 3, padding='valid')(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Max pooling with a 2x2 window and stride 2 halves the height and width
    x = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2)(x)

    return x
Decoder
Next, we define the decoder block:
Python3
def decoder_block(inputs, skip_features, num_filters):
    # Upsample with a 2x2 transposed convolution (doubles height and width)
    x = tf.keras.layers.Conv2DTranspose(num_filters, (2, 2), strides=2,
                                        padding='valid')(inputs)

    # Match the skip features to the shape of the upsampled input.
    # Note: the paper crops the skip features; this code resizes them instead.
    skip_features = tf.image.resize(skip_features,
                                    size=(x.shape[1], x.shape[2]))
    x = tf.keras.layers.Concatenate()([x, skip_features])

    # First 3x3 convolution (no padding) followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters, 3, padding='valid')(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Second 3x3 convolution (no padding) followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters, 3, padding='valid')(x)
    x = tf.keras.layers.Activation('relu')(x)

    return x
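The decoder in this article matches shapes with `tf.image.resize`, whereas the U-Net paper describes copying and center-cropping the encoder features. As a hedged illustration of the cropping alternative, the sketch below only computes the slice bounds of a central window. The helper name `center_crop_slices` and the example shapes are assumptions for illustration, and the sketch is not wired into the model defined in this article.

```python
def center_crop_slices(source_hw, target_hw):
    """Return (row_slice, col_slice) selecting the central target-sized window."""
    src_h, src_w = source_hw
    dst_h, dst_w = target_hw
    if dst_h > src_h or dst_w > src_w:
        raise ValueError("target must not be larger than source")
    top = (src_h - dst_h) // 2
    left = (src_w - dst_w) // 2
    return slice(top, top + dst_h), slice(left, left + dst_w)


# Hypothetical shapes: a 64x64 encoder feature map cropped to a 56x56 decoder map.
rows, cols = center_crop_slices((64, 64), (56, 56))
print(rows, cols)  # slice(4, 60, None) slice(4, 60, None)

# With a channels-last tensor `skip` of shape (batch, 64, 64, channels),
# the cropped features would be skip[:, rows, cols, :].
```

Cropping keeps the original activations at their true positions, while resizing interpolates them, so the two approaches are not interchangeable even though both produce matching shapes.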
U-Net
Using these blocks, we define the U-Net model and print its summary.
Python3
# Unet code
import tensorflow as tf


def unet_model(input_shape=(256, 256, 3), num_classes=1):
    inputs = tf.keras.layers.Input(input_shape)

    # Contracting Path
    s1 = encoder_block(inputs, 64)
    s2 = encoder_block(s1, 128)
    s3 = encoder_block(s2, 256)
    s4 = encoder_block(s3, 512)

    # Bottleneck
    b1 = tf.keras.layers.Conv2D(1024, 3, padding='valid')(s4)
    b1 = tf.keras.layers.Activation('relu')(b1)
    b1 = tf.keras.layers.Conv2D(1024, 3, padding='valid')(b1)
    b1 = tf.keras.layers.Activation('relu')(b1)

    # Expansive Path
    s5 = decoder_block(b1, s4, 512)
    s6 = decoder_block(s5, s3, 256)
    s7 = decoder_block(s6, s2, 128)
    s8 = decoder_block(s7, s1, 64)

    # Output
    outputs = tf.keras.layers.Conv2D(num_classes, 1, padding='valid',
                                     activation='sigmoid')(s8)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs,
                                  name='U-Net')
    return model


if __name__ == '__main__':
    model = unet_model(input_shape=(572, 572, 3), num_classes=2)
    model.summary()
Output:
Model: "U-Net"
__________________________________________________________________________________________________
Layer (type)                            Output Shape            Param #   Connected to
==================================================================================================
input_6 (InputLayer)                    [(None, 572, 572, 3)]   0         []
conv2d_95 (Conv2D)                      (None, 570, 570, 64)    1792      ['input_6[0][0]']
activation_90 (Activation)              (None, 570, 570, 64)    0         ['conv2d_95[0][0]']
conv2d_96 (Conv2D)                      (None, 568, 568, 64)    36928     ['activation_90[0][0]']
activation_91 (Activation)              (None, 568, 568, 64)    0         ['conv2d_96[0][0]']
max_pooling2d_20 (MaxPooling2D)         (None, 284, 284, 64)    0         ['activation_91[0][0]']
conv2d_97 (Conv2D)                      (None, 282, 282, 128)   73856     ['max_pooling2d_20[0][0]']
activation_92 (Activation)              (None, 282, 282, 128)   0         ['conv2d_97[0][0]']
conv2d_98 (Conv2D)                      (None, 280, 280, 128)   147584    ['activation_92[0][0]']
activation_93 (Activation)              (None, 280, 280, 128)   0         ['conv2d_98[0][0]']
max_pooling2d_21 (MaxPooling2D)         (None, 140, 140, 128)   0         ['activation_93[0][0]']
conv2d_99 (Conv2D)                      (None, 138, 138, 256)   295168    ['max_pooling2d_21[0][0]']
activation_94 (Activation)              (None, 138, 138, 256)   0         ['conv2d_99[0][0]']
conv2d_100 (Conv2D)                     (None, 136, 136, 256)   590080    ['activation_94[0][0]']
activation_95 (Activation)              (None, 136, 136, 256)   0         ['conv2d_100[0][0]']
max_pooling2d_22 (MaxPooling2D)         (None, 68, 68, 256)     0         ['activation_95[0][0]']
conv2d_101 (Conv2D)                     (None, 66, 66, 512)     1180160   ['max_pooling2d_22[0][0]']
activation_96 (Activation)              (None, 66, 66, 512)     0         ['conv2d_101[0][0]']
conv2d_102 (Conv2D)                     (None, 64, 64, 512)     2359808   ['activation_96[0][0]']
activation_97 (Activation)              (None, 64, 64, 512)     0         ['conv2d_102[0][0]']
max_pooling2d_23 (MaxPooling2D)         (None, 32, 32, 512)     0         ['activation_97[0][0]']
conv2d_103 (Conv2D)                     (None, 30, 30, 1024)    4719616   ['max_pooling2d_23[0][0]']
activation_98 (Activation)              (None, 30, 30, 1024)    0         ['conv2d_103[0][0]']
conv2d_104 (Conv2D)                     (None, 28, 28, 1024)    9438208   ['activation_98[0][0]']
activation_99 (Activation)              (None, 28, 28, 1024)    0         ['conv2d_104[0][0]']
conv2d_transpose_20 (Conv2DTranspose)   (None, 56, 56, 512)     2097664   ['activation_99[0][0]']
tf.image.resize_20 (TFOpLambda)         (None, 56, 56, 512)     0         ['max_pooling2d_23[0][0]']
concatenate_20 (Concatenate)            (None, 56, 56, 1024)    0         ['conv2d_transpose_20[0][0]', 'tf.image.resize_20[0][0]']
conv2d_105 (Conv2D)                     (None, 54, 54, 512)     4719104   ['concatenate_20[0][0]']
activation_100 (Activation)             (None, 54, 54, 512)     0         ['conv2d_105[0][0]']
conv2d_106 (Conv2D)                     (None, 52, 52, 512)     2359808   ['activation_100[0][0]']
activation_101 (Activation)             (None, 52, 52, 512)     0         ['conv2d_106[0][0]']
conv2d_transpose_21 (Conv2DTranspose)   (None, 104, 104, 256)   524544    ['activation_101[0][0]']
tf.image.resize_21 (TFOpLambda)         (None, 104, 104, 256)   0         ['max_pooling2d_22[0][0]']
concatenate_21 (Concatenate)            (None, 104, 104, 512)   0         ['conv2d_transpose_21[0][0]', 'tf.image.resize_21[0][0]']
conv2d_107 (Conv2D)                     (None, 102, 102, 256)   1179904   ['concatenate_21[0][0]']
activation_102 (Activation)             (None, 102, 102, 256)   0         ['conv2d_107[0][0]']
conv2d_108 (Conv2D)                     (None, 100, 100, 256)   590080    ['activation_102[0][0]']
activation_103 (Activation)             (None, 100, 100, 256)   0         ['conv2d_108[0][0]']
conv2d_transpose_22 (Conv2DTranspose)   (None, 200, 200, 128)   131200    ['activation_103[0][0]']
tf.image.resize_22 (TFOpLambda)         (None, 200, 200, 128)   0         ['max_pooling2d_21[0][0]']
concatenate_22 (Concatenate)            (None, 200, 200, 256)   0         ['conv2d_transpose_22[0][0]', 'tf.image.resize_22[0][0]']
conv2d_109 (Conv2D)                     (None, 198, 198, 128)   295040    ['concatenate_22[0][0]']
activation_104 (Activation)             (None, 198, 198, 128)   0         ['conv2d_109[0][0]']
conv2d_110 (Conv2D)                     (None, 196, 196, 128)   147584    ['activation_104[0][0]']
activation_105 (Activation)             (None, 196, 196, 128)   0         ['conv2d_110[0][0]']
conv2d_transpose_23 (Conv2DTranspose)   (None, 392, 392, 64)    32832     ['activation_105[0][0]']
tf.image.resize_23 (TFOpLambda)         (None, 392, 392, 64)    0         ['max_pooling2d_20[0][0]']
concatenate_23 (Concatenate)            (None, 392, 392, 128)   0         ['conv2d_transpose_23[0][0]', 'tf.image.resize_23[0][0]']
conv2d_111 (Conv2D)                     (None, 390, 390, 64)    73792     ['concatenate_23[0][0]']
activation_106 (Activation)             (None, 390, 390, 64)    0         ['conv2d_111[0][0]']
conv2d_112 (Conv2D)                     (None, 388, 388, 64)    36928     ['activation_106[0][0]']
activation_107 (Activation)             (None, 388, 388, 64)    0         ['conv2d_112[0][0]']
conv2d_113 (Conv2D)                     (None, 388, 388, 2)     130       ['activation_107[0][0]']
==================================================================================================
Total params: 31,031,810
Trainable params: 31,031,810
Non-trainable params: 0
__________________________________________________________________________________________________
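The parameter counts in the summary can be reproduced by hand. The sketch below assumes that each convolution and transposed convolution has `kernel × kernel × in_channels × out_channels` weights plus one bias per output channel, and that the network has the layout used in this article (a 3-channel input, 64 base filters, four levels, two output classes). The function names `conv_params` and `unet_param_count` are hypothetical helpers, not part of TensorFlow.

```python
def conv_params(kernel, in_channels, out_channels):
    """Weights plus one bias per output channel for a 2D (transposed) convolution."""
    return kernel * kernel * in_channels * out_channels + out_channels


def unet_param_count(in_channels=3, num_classes=2, base_filters=64, depth=4):
    total = 0
    channels = in_channels
    filters = base_filters

    # Contracting path and bottleneck: two 3x3 convolutions per level.
    for _ in range(depth + 1):
        total += conv_params(3, channels, filters) + conv_params(3, filters, filters)
        channels = filters
        filters *= 2

    # Expansive path: a 2x2 up-convolution, then two 3x3 convolutions.
    for _ in range(depth):
        filters = channels // 2
        total += conv_params(2, channels, filters)     # up-convolution
        total += conv_params(3, 2 * filters, filters)  # after concatenation
        total += conv_params(3, filters, filters)
        channels = filters

    # Final 1x1 convolution that maps to the class scores.
    return total + conv_params(1, channels, num_classes)


print(conv_params(3, 3, 64))  # 1792, the first Conv2D row
print(unet_param_count())     # 31031810, the reported total
```

Pooling, activation, resize, and concatenation layers contribute no parameters, which is why their rows show 0 in the summary.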
Apply to an Image
Input Image:
Python3
import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing import image

# Load the image and remember its original size
img = Image.open('cat.png')
original_size = img.size  # (width, height)

# Preprocess the image: resize to the network input size and scale to [0, 1]
img = img.resize((572, 572))
img_array = image.img_to_array(img)
# Keep the first three channels and add a batch axis
img_array = np.expand_dims(img_array[:, :, :3], axis=0)
img_array = img_array / 255.

# Build the model. Its weights are untrained here, so the mask is not
# meaningful until the model is trained or trained weights are loaded.
model = unet_model(input_shape=(572, 572, 3), num_classes=2)

# Make predictions
predictions = model.predict(img_array)

# Convert predictions to a class-index mask and resize to the original image size
predictions = np.squeeze(predictions, axis=0)
predictions = np.argmax(predictions, axis=-1)
predictions = Image.fromarray(np.uint8(predictions * 255))
predictions = predictions.resize(original_size)

# Save the predicted image
predictions.save('predicted_image.jpg')
predictions
Output:
1/1 [==============================] - 2s 2s/step
Predicted Image
Applications:
The versatility and flexibility of U-Net have led to its use in various domains beyond biomedical image segmentation. Popular applications include image denoising, image-to-image translation, image super-resolution, object detection, and NLP. You can also explore some of these applications in the following articles:
- Image Segmentation Using TensorFlow
- Image-to-Image Translation using Pix2Pix