This repository contains a PyTorch implementation of the Dynamic Residual Feature Aggregation Network (DRA-Net), a deep learning model for image classification. The model is designed to improve the discriminative power and robustness of feature representations by introducing learnable attractors and attention mechanisms.
The DRA-Net consists of two main components (a minimal code sketch of both follows this list):
- Feature Extractor: A modified ResNet-18 serves as the feature extractor. It can be initialized with pre-trained weights or trained from scratch, and its final fully connected layer is removed so that it outputs feature vectors rather than class logits.
- Dynamic Residual Feature Aggregation Network: This module performs feature aggregation using learnable attractors.
  - Attractors are learned to capture discriminative features for each class.
  - An Adaptive Weights Network generates per-sample attractor weights.
  - A regularization term encourages diversity among the attractors by penalizing attractors that lie close to one another.
  - The aggregated features are added to the original features through a residual connection, and an MLP classifies the result.
  - Helper functions compute intra-class compactness and inter-class distance of the learned features.
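The following is a minimal PyTorch sketch of these two components. The class names (`FeatureExtractor`, `AttractorAggregation`), layer sizes, and the exact form of the diversity regularizer are illustrative assumptions, not the code in `main.py`:

```python
import torch
import torch.nn as nn
import torchvision

class FeatureExtractor(nn.Module):
    """ResNet-18 backbone with the final fully connected layer removed (illustrative)."""
    def __init__(self, pretrained=False):
        super().__init__()
        backbone = torchvision.models.resnet18(
            weights="IMAGENET1K_V1" if pretrained else None
        )
        backbone.fc = nn.Identity()   # drop the classification head, keep 512-d features
        self.backbone = backbone
        self.out_dim = 512

    def forward(self, x):
        return self.backbone(x)       # (batch, 512) feature vectors

class AttractorAggregation(nn.Module):
    """Aggregates features with learnable attractors and a residual connection (illustrative)."""
    def __init__(self, feat_dim=512, num_classes=10, num_attractors=10):
        super().__init__()
        # one learnable attractor vector per class/slot
        self.attractors = nn.Parameter(torch.randn(num_attractors, feat_dim) * 0.01)
        # adaptive weights network: predicts per-sample attractor weights
        self.weight_net = nn.Sequential(
            nn.Linear(feat_dim, feat_dim // 2),
            nn.ReLU(),
            nn.Linear(feat_dim // 2, num_attractors),
            nn.Softmax(dim=-1),
        )
        # MLP classifier applied to the residual-augmented features
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, num_classes),
        )

    def forward(self, feats):
        weights = self.weight_net(feats)          # (batch, num_attractors)
        aggregated = weights @ self.attractors    # weighted sum of attractors
        fused = feats + aggregated                # residual connection
        return self.classifier(fused)

    def diversity_loss(self):
        """Penalize attractors that lie close together, encouraging diversity."""
        dists = torch.cdist(self.attractors, self.attractors)  # pairwise distances
        mask = ~torch.eye(len(self.attractors), dtype=torch.bool, device=dists.device)
        return torch.exp(-dists[mask]).mean()     # small distances -> large penalty
```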
The training process consists of two stages (a condensed training-loop sketch follows this list):
- Feature Extractor Pre-training: The feature extractor is first pre-trained using the classification loss.
- Attractor Network Training: The feature extractor is then frozen, and the attractor network is trained using a weighted sum of the classification loss and the attractor regularization loss.
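A condensed sketch of this two-stage schedule, reusing the hypothetical `FeatureExtractor` and `AttractorAggregation` classes from the sketch above; the optimizers, learning rates, the regularization weight `lambda_reg`, and `train_loader` are illustrative assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

extractor = FeatureExtractor(pretrained=False)
aggregator = AttractorAggregation(feat_dim=512, num_classes=10)
lambda_reg = 0.1  # illustrative weight on the attractor regularizer

# Stage 1: pre-train the feature extractor with a temporary linear head (one epoch shown)
head = nn.Linear(512, 10)
opt1 = optim.SGD(list(extractor.parameters()) + list(head.parameters()), lr=0.1, momentum=0.9)
for images, labels in train_loader:
    loss = F.cross_entropy(head(extractor(images)), labels)
    opt1.zero_grad()
    loss.backward()
    opt1.step()

# Stage 2: freeze the extractor and train only the attractor network
for p in extractor.parameters():
    p.requires_grad = False
extractor.eval()
opt2 = optim.Adam(aggregator.parameters(), lr=1e-3)
for images, labels in train_loader:
    with torch.no_grad():
        feats = extractor(images)                 # fixed features
    logits = aggregator(feats)
    loss = F.cross_entropy(logits, labels) + lambda_reg * aggregator.diversity_loss()
    opt2.zero_grad()
    loss.backward()
    opt2.step()
```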
The model is evaluated for accuracy on the test set, the learned feature representations are visualized using t-SNE, and the intra-class compactness and inter-class distance metrics are computed on the test set; a sketch of these evaluation utilities is shown below.
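The metrics and visualization could be implemented roughly as follows, assuming `feats` is an `(N, D)` NumPy array of test-set features and `labels` holds the corresponding class labels (both names, and the exact metric definitions, are illustrative):

```python
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def intra_class_compactness(feats, labels):
    """Mean distance of samples to their own class centroid (lower = more compact)."""
    dists = []
    for c in np.unique(labels):
        cls = feats[labels == c]
        centroid = cls.mean(axis=0)
        dists.append(np.linalg.norm(cls - centroid, axis=1).mean())
    return float(np.mean(dists))

def inter_class_distance(feats, labels):
    """Mean pairwise distance between class centroids (higher = better separated)."""
    centroids = np.stack([feats[labels == c].mean(axis=0) for c in np.unique(labels)])
    diffs = centroids[:, None, :] - centroids[None, :, :]
    pairwise = np.linalg.norm(diffs, axis=-1)
    n = len(centroids)
    return float(pairwise[np.triu_indices(n, k=1)].mean())

# 2-D t-SNE visualization of the test-set features
emb = TSNE(n_components=2, init="pca", random_state=0).fit_transform(feats)
plt.scatter(emb[:, 0], emb[:, 1], c=labels, s=4, cmap="tab10")
plt.title("t-SNE of DRA-Net features")
plt.savefig("tsne_features.png")
```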
The code uses the CIFAR-10 dataset for training and evaluation. The dataset is downloaded automatically if it is not found in the specified directory; a typical loading snippet is shown below.
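Loading CIFAR-10 follows the standard torchvision pattern; the `./data` root, normalization statistics, and batch sizes below are illustrative choices:

```python
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),  # CIFAR-10 statistics
])

# download=True fetches the dataset into ./data if it is not already there
train_set = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)
```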
Requirements:

- Python 3.x
- PyTorch
- torchvision
- NumPy
- Matplotlib
- scikit-learn
Installation and usage:

- Clone the repository: `git clone https://github.com/your-username/DRA-Net.git`
- Install the required dependencies: `pip install -r requirements.txt`
- Run the code: `python main.py`
The code will start the training process, evaluate the model on the test set, visualize the features using t-SNE, and compute the intra-class compactness and inter-class distance metrics.
The model achieves high accuracy on the CIFAR-10 dataset and demonstrates improved feature representation quality through visualization and metric evaluation.
This project is licensed under the MIT License. See the LICENSE file for details.
- The ResNet-18 model is used as the feature extractor, which was introduced in the paper "Deep Residual Learning for Image Recognition" by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
- The CIFAR-10 dataset is used for training and evaluation, which was introduced in the paper "Learning Multiple Layers of Features from Tiny Images" by Alex Krizhevsky.