U²-Net For Image Segmentation
Introduction
In the paper “A Novel Cascade Machine Learning Pipeline for Alzheimer’s Disease Identification and Prediction,” Zhou K. et al. aimed to diagnose Alzheimer’s disease (AD) by segmenting the hippocampus and classifying it as belonging to either an AD patient or a normal control. The images used were coronal T1-weighted images from 187 AD patients and 230 normal controls.
Notably, coronal T1-weighted images are acquired as 2D sequences in the coronal plane. Compared with 3D T1-weighted sequences, they have lower resolution but shorter acquisition times, and they are often used in routine scans to focus on specific brain regions. The paper acknowledges that, although there has been considerable machine learning (ML) research on diagnosing AD through hippocampal atrophy, most of it has used 3D T1-weighted images.
The hippocampus, part of the brain’s limbic system, contributes to learning, memory, and the regulation of emotional and behavioral responses. Hippocampal atrophy, the shrinkage of the hippocampus, is considered a promising biomarker for diagnosing AD.
The computer-aided diagnosis pipeline used in this paper included preprocessing, hippocampal segmentation, radiomic feature extraction and selection, and classification. This post focuses on the segmentation component of the pipeline.
Segmentation
The goal was to segment the hippocampus from coronal T1-weighted images. Data were collected from 187 AD patients aged 50–75 years at Huashan Hospital in Shanghai, China, who were diagnosed by a qualified neurologist according to standard criteria for amnestic AD. A control group of 230 normal individuals was also included.
The images were manually annotated by neuroradiologists using ITK-SNAP software. The 417 images were then randomly divided into a training set (334 images) and a test set (83 images), roughly an 80/20 split. Before training the U²-Net, the training set was augmented through flipping, translation, and rotation, doubling it to 668 images.
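The augmentation step can be sketched in NumPy. This is a minimal sketch under stated assumptions, not the authors' code: the flip direction, the shift range (up to ±10 pixels), and the rotation range (up to ±10°) are hypothetical, and each training image receives exactly one augmented copy, which is one simple way to go from 334 to 668 images. Whatever the exact settings, the image and its manual mask must receive the same transform, and nearest-neighbour resampling keeps the mask binary.

```python
import numpy as np


def shift(a, dy, dx):
    """Translate a 2D array by (dy, dx) pixels, filling vacated pixels with 0."""
    h, w = a.shape
    out = np.zeros_like(a)
    src = a[max(0, -dy):h - max(0, dy), max(0, -dx):w - max(0, dx)]
    out[max(0, dy):max(0, dy) + src.shape[0],
        max(0, dx):max(0, dx) + src.shape[1]] = src
    return out


def rotate(a, degrees):
    """Rotate a 2D array about its centre (nearest-neighbour, zero fill)."""
    h, w = a.shape
    theta = np.deg2rad(degrees)
    yy, xx = np.mgrid[0:h, 0:w]
    cy, cx = (h - 1) / 2, (w - 1) / 2
    # Inverse mapping: for every output pixel, look up its source pixel.
    sy = np.cos(theta) * (yy - cy) + np.sin(theta) * (xx - cx) + cy
    sx = -np.sin(theta) * (yy - cy) + np.cos(theta) * (xx - cx) + cx
    sy, sx = np.rint(sy).astype(int), np.rint(sx).astype(int)
    valid = (sy >= 0) & (sy < h) & (sx >= 0) & (sx < w)
    out = np.zeros_like(a)
    out[valid] = a[sy[valid], sx[valid]]
    return out


def augment_pair(image, mask, rng):
    """Apply one random flip + shift + rotation identically to image and mask."""
    flip = rng.random() < 0.5
    dy, dx = rng.integers(-10, 11, size=2)   # hypothetical shift range
    angle = rng.uniform(-10, 10)             # hypothetical rotation range (degrees)
    out = []
    for a in (image, mask):
        if flip:
            a = np.flip(a, axis=1)           # left-right flip
        out.append(rotate(shift(a, int(dy), int(dx)), angle))
    return out[0], out[1]


def augment_dataset(images, masks, seed=0):
    """Return the originals plus one augmented copy of each pair (N -> 2N)."""
    rng = np.random.default_rng(seed)
    aug = [augment_pair(i, m, rng) for i, m in zip(images, masks)]
    return (list(images) + [a for a, _ in aug],
            list(masks) + [m for _, m in aug])
```

Called on 334 image/mask pairs, `augment_dataset` returns 668 pairs; a real pipeline would more likely use a library such as torchvision or albumentations, but the key design point (one shared random transform per pair) is the same.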
The segmentation model was a modified U²-Net, an encoder-decoder with a two-level nested structure. One modification concerned the top level: a large U-shaped structure with only five stages. Another important modification was the inclusion of a deep supervision mechanism to improve segmentation performance.
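Deep supervision attaches a loss to intermediate (side) outputs as well as to the final output, so that earlier stages receive a direct training signal. The post does not give Zhou et al.'s exact loss; the sketch below assumes a common form, $\mathcal{L} = \sum_m w_{side}^{(m)} \ell_{side}^{(m)} + w_{fuse} \ell_{fuse}$, where every $\ell$ is a binary cross-entropy (BCE) term and all weights default to 1 (the weight parameters are hypothetical).

```python
import numpy as np


def bce(prob, target, eps=1e-7):
    """Mean pixel-wise binary cross-entropy between a probability map and a 0/1 mask."""
    p = np.clip(prob, eps, 1 - eps)
    return float(np.mean(-(target * np.log(p) + (1 - target) * np.log(1 - p))))


def deep_supervision_loss(side_probs, fused_prob, target,
                          side_weights=None, fuse_weight=1.0):
    """Weighted sum of BCE over every side output plus the fused output.

    side_probs: probability maps already upsampled to the mask size.
    side_weights / fuse_weight: hypothetical weights (default 1 for every term).
    """
    if side_weights is None:
        side_weights = [1.0] * len(side_probs)
    loss = sum(w * bce(p, target) for w, p in zip(side_weights, side_probs))
    return loss + fuse_weight * bce(fused_prob, target)
```

With six side outputs and one fused output that all match the mask exactly, the loss is close to 0; if all seven maps are a flat 0.5, each term is ln 2 and the total is 7·ln 2.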
To understand the modifications made by Zhou K. et al., it is important first to understand the original design of the U²-Net as described in U²-Net: Going Deeper with Nested U-Structure for Salient Object Detection by Qin X. et al.
Understanding the U²-Net
The U²-Net was originally designed for salient object detection (SOD). The goal of SOD is to segment the most visually “important” or “attention-grabbing” objects in an image, mimicking the attention a human viewer would give them.
Unlike semantic segmentation (SS), which labels every pixel with a semantic class, SOD relies on both local features (edges, textures, small regions) and global features (color, shape) to decide what stands out. For example, semantic segmentation would label every cat and person in a scene by class, whereas SOD would segment only the most salient object, say a cat, while paying close attention to fine contours such as its whiskers.
Problems with SOD Algorithms Before U²-Net
- Reliance on Classification Backbones:
Most SOD algorithms use deep features from backbones such as AlexNet, VGG, ResNet, or ResNeXt, which were originally designed for image classification. These backbones extract high-level semantic features suited to classifying whole images but are less effective at capturing local details, such as edges or textures, that are critical for saliency detection.
- Complexity, Memory, and Computation Cost:
Existing SOD algorithms often add feature aggregation modules to capture multi-level saliency, increasing complexity and computational expense.
- Loss of Resolution:
Traditional backbones achieve deeper architectures by downsampling images through pooling or strided convolutions, which reduces the resolution of feature maps. This is computationally efficient but sacrifices the high-resolution details needed for accurate saliency detection.
The goal of Qin X. et al. was to maintain high-resolution feature maps while reducing memory and computational costs.
Architecture of the U²-Net
The U²-Net employs a two-level nested U-structure:
- Bottom Level:
The Residual U-block (RSU block) is designed to extract local, intra-stage, multi-scale features without degrading feature map resolution.
- Top Level:
A U-Net-like structure aggregates inter-stage, multi-level features.
Thus, U²-Net combines several RSU blocks with a U-Net-like structure (encoders, decoders, and a saliency map fusion module). A detailed description of both the top and bottom levels is below.
ReSidual U-Block (RSU)
The RSU captures intra-stage, multi-scale features. Its structure is denoted RSU-L(Cin, M, Cout), where:
- L: Number of layers in the encoder-decoder block.
- Cin, Cout: Input and output channels.
- M: Number of channels in the internal layers of the RSU.
Components of the RSU:
- Input Convolutional Layer:
Transforms the input feature map (H, W, Cin) into an intermediate feature map $F_1(x)$ with Cout channels, extracting local features.
- U-Net-Like Symmetric Encoder-Decoder:
Takes the intermediate feature map $F_1(x)$ as input and learns to extract and encode multi-scale contextual information. High-resolution feature maps are then reconstructed through upsampling, concatenation, and convolution. The larger L is, the more pooling operations the block performs, the larger its receptive field, and the richer the local and global features it extracts.
- Residual Connection:
Combines local and multi-scale features through summation: $F_1(x) + U(F_1(x))$, where $U$ denotes the U-shaped encoder-decoder.
RSU blocks enhance segmentation accuracy by preserving fine details and spatial resolution.
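To make the data flow of an RSU-L(Cin, M, Cout) block concrete, here is a toy NumPy forward pass. It is a structural sketch, not Qin et al.'s implementation: the weights are random and untrained, batch normalization is omitted, 3×3 convolutions are computed naively, and nearest-neighbour upsampling stands in for bilinear upsampling. The number of pooling steps (L − 2) and the dilation rate at the bottom of the U (2) are assumptions, and the helper names `conv3x3`, `pool2`, `upsample_to`, and `rsu` are hypothetical.

```python
import numpy as np


def conv3x3(x, w, dilation=1):
    """'Same'-padded 3x3 (optionally dilated) convolution followed by ReLU.

    x: (H, W, Cin) feature map; w: (3, 3, Cin, Cout) weights. BatchNorm omitted.
    """
    d = dilation
    h, wd, _ = x.shape
    xp = np.pad(x, ((d, d), (d, d), (0, 0)))
    out = np.zeros((h, wd, w.shape[3]))
    for i in range(3):
        for j in range(3):
            out += xp[i * d:i * d + h, j * d:j * d + wd, :] @ w[i, j]
    return np.maximum(out, 0.0)


def pool2(x):
    """2x2 max pooling, stride 2 (odd sizes padded by edge replication)."""
    h, w, c = x.shape
    x = np.pad(x, ((0, h % 2), (0, w % 2), (0, 0)), mode="edge")
    return x.reshape(x.shape[0] // 2, 2, x.shape[1] // 2, 2, c).max(axis=(1, 3))


def upsample_to(x, h, w):
    """Nearest-neighbour upsampling of x to spatial size (h, w)."""
    rows = np.arange(h) * x.shape[0] // h
    cols = np.arange(w) * x.shape[1] // w
    return x[rows][:, cols]


def rsu(x, L, cin, mid, cout, rng):
    """Toy forward pass of RSU-L(cin, mid, cout) with random, untrained weights."""
    def weights(ci, co):
        return rng.standard_normal((3, 3, ci, co)) * np.sqrt(2.0 / (9 * ci))

    assert x.shape[-1] == cin
    f1 = conv3x3(x, weights(cin, cout))            # input layer: F1(x), Cout channels
    skips, h = [], f1
    for k in range(L - 1):                         # encoder: L-1 convs, L-2 poolings
        if k > 0:
            h = pool2(h)
        h = conv3x3(h, weights(cout if k == 0 else mid, mid))
        skips.append(h)
    d = conv3x3(h, weights(mid, mid), dilation=2)  # bottom: dilated conv, no pooling
    for k in range(L - 2, -1, -1):                 # decoder: upsample, concat, conv
        s = skips[k]
        d = upsample_to(d, s.shape[0], s.shape[1])
        co = cout if k == 0 else mid
        d = conv3x3(np.concatenate([d, s], axis=-1), weights(2 * mid, co))
    return f1 + d                                  # residual: F1(x) + U(F1(x))
```

For example, `rsu(np.random.default_rng(0).standard_normal((20, 20, 3)), 7, 3, 4, 8, np.random.default_rng(1))` returns a (20, 20, 8) map: a larger L adds pooling steps (and receptive field) inside the block, yet the output keeps the input's H × W.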
U²-Net Architecture
Earlier research has explored stacking multiple U-Nets for different tasks, typically as cascades (U×n-Net), where n is the number of repeated U-Net models. However, this multiplies computation and memory costs by n.
U²-Net instead nests U-Nets: each stage of a large U-structure is itself a small U-structure (an RSU).
At the top level of U²-Net there are 11 stages, each filled by a well-configured RSU. This nested design enables both the extraction of intra-stage multi-scale features and the aggregation of inter-stage multi-level features.
The 11 stages comprise 6 encoder stages and 5 decoder stages. In addition, a saliency map fusion module is attached to the decoder stages and the last encoder stage (En_6); it generates the saliency probability maps.
Saliency probability maps assign each pixel a probability of belonging to a salient region; thresholding them yields a binary mask of the regions of interest (the segmented regions).
Encoders
En_1 to En_4 use RSU-7, RSU-6, RSU-5, and RSU-4, respectively, so L decreases from 7 to 4 as the stages go deeper. L is configured according to the spatial resolution of the input feature maps: for feature maps with large H and W, a larger L is used to capture more large-scale features.
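A quick way to see why L shrinks with depth is to compute the spatial size of the innermost feature map in each encoder stage. This sketch assumes a 320 × 320 input (an illustrative choice), one 2 × 2 stride-2 pooling between consecutive stages, and L − 2 such poolings inside an RSU-L, all with ceiling rounding; the helper names are hypothetical.

```python
import math


def pooled(size, times):
    """Spatial size after `times` rounds of 2x2, stride-2 pooling (ceil mode)."""
    for _ in range(times):
        size = math.ceil(size / 2)
    return size


def encoder_resolutions(input_size=320, depths=(7, 6, 5, 4)):
    """For En_1..En_4 (RSU-7..RSU-4): name, L, stage input size, innermost size."""
    rows = []
    for stage, L in enumerate(depths):
        size = pooled(input_size, stage)      # one pooling between stages
        rows.append((f"En_{stage + 1}", L, size, pooled(size, L - 2)))
    return rows


for name, L, size, inner in encoder_resolutions():
    print(f"{name}: RSU-{L}, input {size}x{size}, innermost {inner}x{inner}")
# En_1: RSU-7, input 320x320, innermost 10x10
# En_2: RSU-6, input 160x160, innermost 10x10
# En_3: RSU-5, input 80x80, innermost 10x10
# En_4: RSU-4, input 40x40, innermost 10x10
```

Under these assumptions every RSU in En_1 to En_4 bottoms out at the same 10 × 10 resolution, while the inputs to En_5 and En_6 are already only 20 × 20 and 10 × 10, leaving little room for further pooling.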
The feature maps of En_5 and En_6 have low resolution, and further downsampling would lose useful information. Both stages therefore use RSU-4F, where F means that the RSU is a dilated version: its pooling and upsampling operations are replaced with dilated convolutions.
Dilated convolutions enlarge the receptive field without increasing the number of parameters or the computational cost of each layer, and they do so while preserving spatial resolution, which makes them well suited to tasks like segmentation.
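The claims about dilated convolutions can be checked with a little arithmetic. For stride-1 3 × 3 convolutions with dilation rates d_1, …, d_n, the receptive field is 1 + 2·Σd_i, each layer has 9·Cin·Cout weights regardless of its dilation, and padding each layer by d keeps H and W unchanged. The doubling schedule 1, 2, 4, 8 used below is an assumed, illustrative choice.

```python
def receptive_field(dilations, kernel=3):
    """Receptive field of stacked stride-1 convolutions with the given dilations."""
    return 1 + sum((kernel - 1) * d for d in dilations)


def conv_params(cin, cout, kernel=3):
    """Weights in one conv layer (bias ignored); independent of the dilation rate."""
    return kernel * kernel * cin * cout


def same_output_size(size, dilation, kernel=3):
    """Output size with stride 1 and padding = dilation * (kernel - 1) // 2."""
    pad = dilation * (kernel - 1) // 2
    return size + 2 * pad - dilation * (kernel - 1)


plain = [1, 1, 1, 1]
dilated = [1, 2, 4, 8]   # assumed schedule, for illustration
print(receptive_field(plain), receptive_field(dilated))   # 9 31
print([same_output_size(20, d) for d in dilated])         # [20, 20, 20, 20]
```

Four plain layers and four dilated layers cost the same `4 * conv_params(cin, cout)` weights, yet the dilated stack sees a 31 × 31 window instead of 9 × 9 on a 20 × 20 map that never shrinks.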
Decoder
Each decoder stage has a structure similar to the encoder stage symmetric to it about En_6: De_5 uses RSU-4F like En_5, and De_4 to De_1 use RSU-4 to RSU-7 like En_4 to En_1. Each decoder takes as input the concatenation of the upsampled feature maps from the previous stage and the feature maps from its symmetric encoder stage.
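Because of the concatenation, a decoder stage receives the channels of both the stage below it and its mirrored encoder stage. The sketch below does that bookkeeping. The output-channel numbers are an assumption, taken from my reading of the full-size configuration in Qin et al. (verify against the paper); the point is the arithmetic, not the specific values.

```python
# Assumed stage table (full-size U2-Net): name -> (RSU type, output channels).
ENCODER = {
    "En_1": ("RSU-7", 64), "En_2": ("RSU-6", 128), "En_3": ("RSU-5", 256),
    "En_4": ("RSU-4", 512), "En_5": ("RSU-4F", 512), "En_6": ("RSU-4F", 512),
}
DECODER_OUT = {"De_5": 512, "De_4": 256, "De_3": 128, "De_2": 64, "De_1": 64}


def decoder_inputs():
    """Input channels of De_5..De_1: upsampled previous output + mirrored encoder skip."""
    prev_name, prev_out = "En_6", ENCODER["En_6"][1]
    result = {}
    for k in range(5, 0, -1):
        name, skip = f"De_{k}", f"En_{k}"
        rsu_type = ENCODER[skip][0]          # mirrors its symmetric encoder stage
        result[name] = (rsu_type, prev_out + ENCODER[skip][1], f"{prev_name} + {skip}")
        prev_name, prev_out = name, DECODER_OUT[name]
    return result


for name, (rsu_type, cin, source) in decoder_inputs().items():
    print(f"{name}: {rsu_type}, Cin = {cin} ({source})")
```

Each decoder's Cin is simply the sum of the two channel counts being concatenated, for example De_5 receives 512 + 512 = 1024 channels from En_6 and En_5.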
Saliency map fusion module
The fusion module is attached to the decoder stages and the last encoder stage, and it generates the saliency probability maps. U²-Net produces 6 side-output saliency maps, from En_6, De_5, De_4, De_3, De_2, and De_1, each through a 3 × 3 convolution layer and a sigmoid function. It then upsamples these maps to the size of the input image and fuses them with a concatenation operation, followed by a 1 × 1 convolution and a sigmoid function, to generate the final saliency probability map $S_{fuse}$.
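The fusion step can be sketched in NumPy. Assumptions: the six side outputs are single-channel logits at decreasing resolutions (random here, for a hypothetical 64 × 64 input), nearest-neighbour upsampling stands in for bilinear, and a 1 × 1 convolution over six stacked maps reduces to a per-pixel weighted sum plus bias with hypothetical weights. This sketch fuses the pre-sigmoid side logits and applies the sigmoid afterwards; fusing the post-sigmoid side maps instead would be an equally small change.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def upsample_to(x, h, w):
    """Nearest-neighbour upsampling of a 2D map to (h, w)."""
    rows = np.arange(h) * x.shape[0] // h
    cols = np.arange(w) * x.shape[1] // w
    return x[rows][:, cols]


def fuse_side_outputs(side_logits, out_hw, weights, bias=0.0):
    """Upsample side logits, stack them as channels, apply a 1x1 conv, then sigmoid.

    Returns (list of side probability maps, fused probability map S_fuse).
    """
    h, w = out_hw
    ups = np.stack([upsample_to(s, h, w) for s in side_logits], axis=-1)  # (h, w, 6)
    side_probs = [sigmoid(ups[..., i]) for i in range(ups.shape[-1])]
    fused = sigmoid(ups @ np.asarray(weights) + bias)   # 1x1 conv: 6 -> 1 channel
    return side_probs, fused


rng = np.random.default_rng(0)
# Side outputs from En_6, De_5, De_4, De_3, De_2, De_1 (1/32 ... 1/1 of 64x64).
sizes = [2, 4, 8, 16, 32, 64]
sides = [rng.standard_normal((s, s)) for s in sizes]
probs, s_fuse = fuse_side_outputs(sides, (64, 64), weights=np.full(6, 1 / 6))
mask = s_fuse > 0.5   # hypothetical threshold to obtain a binary segmentation mask
```

With equal weights of 1/6 and zero bias, $S_{fuse}$ is just the sigmoid of the mean side logit; in a trained network the 1 × 1 weights are learned, so the fusion can favour coarse or fine side outputs as needed.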
Advantages of the U² Net for SOD
- Captures more contextual information at different scales, thanks to the mixture of receptive fields of different sizes within each RSU.
- Increases the depth of the whole architecture without significantly increasing the computational cost, because the pooling inside the RSU blocks keeps much of the computation on downsampled feature maps.
- Trains a deep network from scratch, without relying on backbones pretrained for image classification.
Can you hypothesize why Zhou K. et al. used a 5-stage U²-Net instead of the 11-stage original?
A member of my team thinks the 5 stages were chosen because of the size of the dataset: a deeper network would be more prone to problems like overfitting, hence the fewer stages.
Inferences: Hypothesizing the efficiency of the 5-stage over the 11-stage U²-Net
- Model complexity: Reducing the number of stages to 5 lowers model complexity, making the model more suitable for training on limited data (in this case, 668 augmented training images).
- Task-specific requirements: SOD involves identifying complex objects in high-resolution images, whereas in hippocampal segmentation the target structure is relatively small and localized.
- Computational efficiency: A shallower network is cheaper and easier to train.
- Resolution considerations: Coronal T1-weighted images have lower resolution, so fewer stages are needed to capture relevant features at different scales, and a deeper network is unnecessary.
- Avoiding redundancy: In this task, the region of interest (ROI) is relatively well defined and the surrounding context adds limited information, so a deeper architecture might not yield better results. Five stages likely strike a balance between efficiency and performance without redundant feature extraction.
Results
According to the paper, the segmentation results improved both qualitatively and quantitatively after data augmentation.
This post explored the segmentation pipeline presented by Zhou et al. We took a deeper look at the U²-Net architecture proposed by Qin et al., examining the structure and function of the RSU blocks, highlighting the advantages of the design, hypothesizing why a 5-stage variant suits hippocampal segmentation, and closing with the reported segmentation results.
Resources
- Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O. R., & Jagersand, M. (2020). U²-Net: Going deeper with nested U-structure for salient object detection. Pattern Recognition, 106, 107404.
- Zhou, K., Piao, S., Liu, X., Luo, X., Chen, H., Xiang, R., & Geng, D. (2023). A novel cascade machine learning pipeline for Alzheimer’s disease identification and prediction. Frontiers in Aging Neuroscience, 14, 1073909.
- OpenAI. (2024). ChatGPT [Large language model].