Christian Mills – Hands-On Hand Gesture Recognition: Fine-Tuning Image Classifiers with PyTorch and the timm library for Beginners

Introduction

Welcome to this hands-on guide to fine-tuning image classifiers with PyTorch and the timm library! Fine-tuning refers to taking a pre-trained model and adjusting its parameters using a new dataset to enhance its performance on a specific task. We can leverage pre-trained models to achieve high performance even when working with limited data and computational resources. The timm library further aids our goal with its wide range of pre-trained models, catering to diverse needs and use cases.

In this tutorial, we develop a hand gesture recognizer. Hand gesture recognition has many real-world applications, ranging from human-computer interaction and sign-language translation to creating immersive gaming experiences. By the end of this tutorial, you will have a practical hand gesture recognizer and a solid foundation to apply to other image classification tasks. You’ll also be able to interact with a model trained with this tutorial’s code through an in-browser demo that runs locally on your computer. Check out the video below for a quick preview.

This guide is structured so that you don’t need a deep understanding of deep learning to complete it. If you follow the instructions, you can make it through! Yet, if you are eager to delve deeper into machine learning and deep learning, I recommend fast.ai’s Practical Deep Learning for Coders course. The course employs a hands-on approach that starts you off training models from the get-go and gradually digs deeper into the foundational concepts.

Let’s dive in and start training our hand gesture classifier!

Getting Started with the Code

The tutorial code is available as a Jupyter Notebook, which you can run locally or in a cloud-based environment like Google Colab. If you’re new to these platforms or need some guidance setting up, I’ve created dedicated tutorials to help you get started.

No matter your choice of environment, you’ll be well-prepared to follow along with the rest of this tutorial. You can download the notebook from the tutorial’s GitHub repository or open the notebook directly in Google Colab using the links below.

Setting Up Your Python Environment

Before diving into the code, we’ll create a Python environment and install the necessary libraries. Creating a dedicated environment will ensure our project has all its dependencies in one place and does not interfere with other Python projects you may have.

Please note that this section is for readers setting up a local Python environment on their machines. If you’re following this tutorial on a cloud-based platform like Google Colab, the platform already provides an isolated environment with many Python libraries pre-installed. In that case, you may skip this section and directly proceed to the code sections. However, you may still need to install certain libraries specific to this tutorial using similar pip install commands within your notebook. The dedicated Colab Notebook contains the instructions for running it in Google Colab.

Creating a Python Environment

First, we’ll create a Python environment using Conda. Conda is a package manager that can create isolated Python environments. These environments are like sandboxed spaces where you can install Python libraries without affecting the rest of your system.

To create a new Python environment, open a terminal with Conda/Mamba installed and run the following commands:
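The commands below are a representative sketch of that step; the environment name and Python version are placeholders you can adjust to your preference:

    # Create a new Conda environment (the name and Python version are examples)
    conda create --name pytorch-env python=3.10 -y
    # Activate the new environment
    conda activate pytorch-env

With the environment active, we’ll install the libraries this tutorial depends on: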

  • datasets: A library for accessing and sharing datasets for Audio, Computer Vision, and Natural Language Processing (NLP) tasks.
  • jupyter: An open-source web application that allows you to create and share documents that contain live code, equations, visualizations, and narrative text.
  • matplotlib: This package provides a comprehensive collection of visualization tools to create high-quality plots, charts, and graphs for data exploration and presentation.
  • pandas: This package provides fast, powerful, and flexible data analysis and manipulation tools.
  • pillow: The Python Imaging Library adds image processing capabilities.
  • timm: The timm library provides state-of-the-art (SOTA) computer vision models, layers, utilities, optimizers, schedulers, data loaders, augmentations, and training/evaluation scripts.
  • torcheval: A library that provides simple and easy-to-use tooling for evaluating PyTorch models.
  • tqdm: A Python library that provides fast, extensible progress bars for loops and other iterable objects in Python.
  • Jupyter Client: This package contains the reference implementation of the Jupyter protocol. It also provides client and kernel management APIs for working with kernels. We will install an older version than the one included with Jupyter (<8) to avoid an issue that causes the training notebook to freeze during training (link).
  • PyZMQ: This package provides Python bindings for ZeroMQ, a lightweight and fast messaging implementation used by Jupyter Notebooks. We will install an older version than the one included with Jupyter (<25) to avoid an issue that causes the training notebook to freeze during training (link).

To install these additional libraries, we’ll use the following command:
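The commands below sketch that step. Install PyTorch first, using the command from pytorch.org that matches your platform, then the libraries from the list above, pinning jupyter_client and pyzmq to the older versions noted earlier:

    # Install PyTorch (see pytorch.org for a command matching your GPU/CUDA setup)
    pip install torch torchvision
    # Install the additional libraries described above
    pip install datasets jupyter matplotlib pandas pillow timm torcheval tqdm
    # Pin jupyter_client and pyzmq to avoid the notebook-freezing issue mentioned above
    pip install "jupyter_client<8" "pyzmq<25"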

We’ll also install some utility packages I created that provide shortcuts for working with PIL images (cjm_pil_utils), interacting with PyTorch (cjm_pytorch_utils), and working with pandas DataFrames (cjm_pandas_utils):
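Something like the following should work; the package names match the utilities referenced above:

    # Utility packages providing helper functions used throughout the tutorial
    pip install cjm_pandas_utils cjm_pil_utils cjm_pytorch_utils

With everything installed, we can import the dependencies we’ll use throughout this tutorial: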

  • datasets: We host the dataset on the HuggingFace Hub, and this package allows us to load our dataset with a single line of code.
  • matplotlib: We use the matplotlib package to explore the dataset samples and class distribution.
  • NumPy: We’ll use it to store PIL Images as arrays of pixel values.
  • pandas: We use Pandas DataFrame and Series objects to format data as tables.
  • PIL (Pillow): We’ll use it for opening and working with image files.
  • Python Standard Library dependencies: These are built-in modules that come with Python. We’ll use them for various tasks like handling file paths (pathlib.Path), manipulating JSON files (json), random number generation (random), multiprocessing (multiprocessing), mathematical operations (math), copying Python objects (copy), file matching patterns (glob), working with dates and times (datetime), and interacting with the operating system (os).
  • PyTorch dependencies: We’ll use PyTorch’s various modules for building our model, processing data, and training.
  • timm library: We’ll use the timm library to download and prepare a pre-trained model for fine-tuning.
  • tqdm: We use the library to track the progress of longer processes like training.
  • Utility functions: These are helper functions from the packages we installed earlier. They provide shortcuts for routine tasks and keep our code clean and readable.

Setting a Random Number Seed

Next, we’ll set the random number seed using the set_seed function from the cjm_pytorch_utils package. A fixed seed value is helpful when training deep-learning models for reproducibility, debugging, and comparison. Having reproducible results allows others to confirm your findings. Using a fixed seed can make it easier to find bugs, as it ensures the same inputs produce the same outputs. Likewise, using fixed seed values lets you compare performance between models and training parameters. That said, it’s often a good idea to test different seed values to see how your model’s performance varies between them. Also, don’t use a fixed seed value when you deploy the final model.
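A minimal sketch of this step is below; the import path and seed value are assumptions (check the cjm_pytorch_utils documentation for the exact module layout):

    from cjm_pytorch_utils.core import set_seed

    # The specific seed value is arbitrary; any fixed integer works.
    seed = 1234
    set_seed(seed)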

Next, we’ll determine which device to run our computations on using the get_torch_device function from the cjm_pytorch_utils package.

PyTorch can run on either a CPU or a GPU. The get_torch_device function will automatically check if a supported Nvidia or Mac GPU is available. Otherwise, it will use the CPU. We’ll use the device and type variables to ensure all tensors and model weights are on the correct device and have the same data type. Otherwise, we might get errors.
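Here is a sketch of what that looks like; the import path and the float32 data type are assumptions for illustration:

    import torch
    from cjm_pytorch_utils.core import get_torch_device

    # Use a CUDA or Apple Metal (MPS) GPU when available, otherwise fall back to the CPU.
    device = get_torch_device()
    # Use 32-bit floats for model weights and input tensors (an assumed default).
    dtype = torch.float32
    print(device, dtype)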

The dataset we’ll use is a version of HaGRID (HAnd Gesture Recognition Image Dataset) that I modified for image classification tasks. The dataset contains images for 18 distinct hand gestures and an additional no_gesture class for idle hands. The dataset is approximately 3.8 GB, but you will need about 7.6 GB to store the archive file and extracted dataset.

The following steps demonstrate how to load the dataset from the HuggingFace Hub, inspect the dataset, and visualize some sample images.

Setting the Dataset Path

We’ll first set up the path for our dataset. We’ll construct the HuggingFace Hub dataset name by combining the username and the dataset name. We then define where to cache the dataset locally.
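The cell below sketches this step; the username, dataset name, and cache directory are placeholders rather than the tutorial’s exact values:

    from pathlib import Path
    from datasets import load_dataset

    # Placeholder Hub username and dataset name; replace them with the actual values.
    hf_username = 'username'
    dataset_name = 'hand-gesture-dataset'
    hf_dataset = f'{hf_username}/{dataset_name}'

    # Directory where the dataset gets cached locally.
    dataset_dir = Path('./Datasets/')
    dataset_dir.mkdir(parents=True, exist_ok=True)

    # Download the dataset from the HuggingFace Hub (or load it from the local cache).
    dataset = load_dataset(hf_dataset, cache_dir=str(dataset_dir))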

For this tutorial, we’ll use the ResNet18 family of models. ResNet18 models are popular for image classification tasks due to their balance of accuracy and speed.

Specifically, we’ll fine-tune the ResNet18-D model. This model’s balance of accuracy and speed makes it suitable for real-time applications, such as hand gesture recognition. While this model is a good all-rounder, others may work better for specific applications. For example, some models are designed to run on mobile devices and may sacrifice some accuracy for improved performance. Whatever your requirements are, the timm library likely has a suitable model for your needs. Feel free to try different models and see how they compare.

Inspecting the ResNet18-D Model Configuration

Next, we will inspect the configuration of our chosen model. The model config gives us information about the pretraining process for the model.
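Here is a sketch of loading the model and resolving its pretraining data configuration with timm; 'resnet18d' is timm’s identifier for this variant, and the num_classes value reflects the dataset’s 18 gestures plus the no_gesture class:

    import timm
    from timm.data import resolve_data_config

    # Load the pretrained ResNet18-D model with a new 19-class classification head.
    model = timm.create_model('resnet18d', pretrained=True, num_classes=19)

    # Resolve the data configuration used during pretraining: expected input size,
    # normalization mean/std, interpolation method, and crop percentage.
    data_cfg = resolve_data_config({}, model=model)
    print(data_cfg)

The normalization statistics in this configuration are what we’ll use later when preparing images for the model.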

For data augmentation, we’ll use trivial augmentation, which applies a single, random transform to each image. Despite its simplicity, this method can be highly effective.

However, we’ll need to create a custom version of the TrivialAugmentWide class from PyTorch’s transforms module, as some of the default parameters are not ideal for this dataset. This custom class defines a dictionary of operations for augmenting the images, and we can customize each operation’s parameters.

Trivial Augmentation
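Below is a sketch of such a customized class built on torchvision’s TrivialAugmentWide; the operations included and their magnitude ranges are illustrative assumptions, not the tutorial’s exact choices:

    import torch
    from torchvision import transforms

    # A customized TrivialAugmentWide: we override _augmentation_space to control
    # which operations get sampled and how strong their magnitudes can be.
    class CustomTrivialAugmentWide(transforms.TrivialAugmentWide):
        def _augmentation_space(self, num_bins: int):
            return {
                # operation name: (magnitude values, whether the magnitude can be negated)
                "Identity": (torch.tensor(0.0), False),
                "ShearX": (torch.linspace(0.0, 0.25, num_bins), True),
                "ShearY": (torch.linspace(0.0, 0.25, num_bins), True),
                "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
                "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
                "Rotate": (torch.linspace(0.0, 45.0, num_bins), True),
                "Brightness": (torch.linspace(0.0, 0.75, num_bins), True),
                "Contrast": (torch.linspace(0.0, 0.75, num_bins), True),
                "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
                "Equalize": (torch.tensor(0.0), False),
            }

    trivial_aug = CustomTrivialAugmentWide()

In this sketch, overriding _augmentation_space lets us leave out operations and cap the magnitudes of the remaining transforms, which is one way to address default parameters that are not ideal for this dataset.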

Next, we’ll define a custom Dataset class that works with PyTorch’s DataLoader to create batches. This class fetches a sample from the dataset at a given index and returns the transformed image and its corresponding label index.
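A minimal sketch of such a class is below. It assumes we have a list of image file paths, a mapping from each path to its class name, and an ordered list of class names; the attribute names are illustrative, not the tutorial’s exact ones:

    from torch.utils.data import Dataset
    from PIL import Image

    class ImageDataset(Dataset):
        def __init__(self, img_paths, img_to_class, class_names, transforms=None):
            self.img_paths = img_paths          # list of image file paths
            self.img_to_class = img_to_class    # maps an image path to its class name
            self.class_names = class_names      # ordered list of class names
            self.transforms = transforms        # image transforms (resize, augment, to-tensor)

        def __len__(self):
            return len(self.img_paths)

        def __getitem__(self, index):
            # Fetch the image at the given index and look up its label index.
            img_path = self.img_paths[index]
            image = Image.open(img_path).convert('RGB')
            label = self.class_names.index(self.img_to_class[img_path])
            if self.transforms:
                image = self.transforms(image)
            return image, label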

With the datasets defined, the next step is to create the DataLoaders, which are used to efficiently create batches of data for the model to process during training.

Training Batch Size

Next, we set the batch size for training. This number indicates how many sample images get fed to the model at once. The larger the batch size, the more GPU memory we need. The current batch size should be fine for most modern GPUs. If you get an out-of-memory error, try lowering the batch size to 8, then restart the Jupyter Notebook.
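As a minimal sketch (the specific value here is an assumption):

    # Number of samples fed to the model at once; lower this if you hit GPU memory errors.
    bs = 32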

Initialize DataLoaders

We initialize the DataLoaders for the training and validation datasets. We’ll set the number of worker processes for loading data to the number of available CPUs.
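A sketch of that step is below, assuming train_dataset and valid_dataset are instances of the Dataset class defined earlier and bs is the batch size set above:

    import multiprocessing
    from torch.utils.data import DataLoader

    # Use one worker process per available CPU core to load data in parallel.
    num_workers = multiprocessing.cpu_count()

    # Shuffle the training data each epoch; keep the validation order fixed.
    train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=num_workers)
    valid_dataloader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=num_workers)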

We’ll use AdamW as our optimizer, which includes weight decay for regularization, and the OneCycleLR scheduler to adjust the learning rate during training. The one-cycle learning rate policy starts the learning rate low, increases it gradually to a maximum, and then decreases it again over the course of training, aiming to converge faster and yield better performance.
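Here is a sketch of setting both up; the learning rate, weight decay, and epoch count are illustrative assumptions rather than the tutorial’s exact settings:

    from torch.optim import AdamW
    from torch.optim.lr_scheduler import OneCycleLR

    # Example hyperparameters
    lr = 1e-3
    epochs = 3

    # AdamW applies weight decay for regularization.
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)

    # The scheduler needs the total number of optimization steps so it can ramp the
    # learning rate up to max_lr and back down over the course of training.
    lr_scheduler = OneCycleLR(optimizer,
                              max_lr=lr,
                              total_steps=epochs * len(train_dataloader))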

We’ll use Multiclass Accuracy as our performance metric, as this is a multiclass classification problem where each image falls into one of many classes.
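With torcheval, that looks roughly like this:

    from torcheval.metrics import MulticlassAccuracy

    # Tracks the fraction of predictions that match the ground-truth labels.
    metric = MulticlassAccuracy()

    # Inside the training/validation loops we would call:
    #   metric.update(predictions, targets)
    # and read the running accuracy with:
    #   metric.compute()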

Train the Model

Finally, we can train the model using the train_loop function. Training time will depend on the available hardware. Feel free to take a break if the progress bar indicates it will take a while.

Training usually takes around 1 hour and 20 minutes on the free GPU tier of Google Colab.

To prepare an image for inference, the resize_img function will scale the image so the smallest dimension is the specified inference size while maintaining the original aspect ratio.

We then convert the image to a tensor using the pil_to_tensor function and move it to the device where our model resides (CPU or GPU).

The model outputs a raw score for each class, so we apply the Softmax function to convert these values into probabilities that sum up to 1.

Finally, we can store the class probabilities in a Pandas Series and print it.
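Putting those steps together might look like the sketch below. The image path, inference size, and the exact signatures of the resize_img and pil_to_tensor helpers are assumptions; check the cjm_pil_utils and cjm_pytorch_utils documentation for the precise parameters:

    from PIL import Image
    import pandas as pd

    # Normalization statistics from the pretraining data config resolved earlier.
    mean, std = data_cfg['mean'], data_cfg['std']

    # Load a test image (placeholder path) and scale its smallest side to the
    # assumed inference size while preserving the aspect ratio.
    test_img = Image.open('test_image.jpg').convert('RGB')
    input_img = resize_img(test_img, target_sz=288)

    # Convert the PIL image to a normalized tensor and move it to the model's device.
    img_tensor = pil_to_tensor(input_img, mean=mean, std=std).to(device=device)

    # Run inference and convert the raw scores into class probabilities.
    model.eval()
    with torch.no_grad():
        pred = model(img_tensor)
    pred_probs = torch.softmax(pred, dim=1).cpu().squeeze()

    # Display the probability for each class as a pandas Series.
    print(pd.Series(dict(zip(class_names, pred_probs.tolist()))).sort_values(ascending=False))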

You can also test the model on new images, such as free stock photos from Pexels.

If you’re training in Google Colab, keep these points in mind:

  1. Don’t forget to download the model checkpoint and class labels from the Colab environment once training finishes. (tutorial link)
  2. Once you finish training and download the files, turn off hardware acceleration for the Colab Notebook to save GPU time. (tutorial link)

Exploring the In-Browser Demo

You’ve gotten your hands dirty with the code. Now let’s see our fine-tuned model in action! I’ve set up an online demo that allows you to interact with a hand gesture recognizer trained with this tutorial’s code in your web browser. No downloads or installations are required.

The demo includes sample images that you can use to test the model. Try these images first to see how the model interprets different hand gestures. Once you’re ready, you can switch on your webcam to provide live input to the model.

Online demos are a great way to see and share the fruits of your labor and explore ways to apply your hand gesture recognizer in real-world scenarios.

I invite you to share any interesting results or experiences with the demo in the comments below. Whether it’s a tricky input image the model handles or a surprising failure case, I’d love to hear about it!

Check out the demo below, and have fun exploring!

Conclusion

Congratulations on completing this tutorial on fine-tuning image classifiers with PyTorch and the timm library! You’ve taken significant strides in your machine learning journey by creating a practical hand gesture recognizer.

Throughout this tutorial, we’ve covered many topics, including setting up your Python environment, importing necessary dependencies, project initialization, dataset loading and exploration, model selection, data preparation, and model fine-tuning. Finally, we made predictions with our fine-tuned model on individual images and tested the model with an interactive, in-browser demo.

This hands-on tutorial underscored the practical applications of fine-tuning image classification models, especially when working with limited data and computational resources. The hand gesture recognizer you’ve built has many real-world applications, and you now have a solid foundation to tackle other image classification tasks.

If you’re intrigued by the underlying concepts leveraged in this tutorial and wish to deepen your understanding, I recommend fast.ai’s Practical Deep Learning for Coders course. By the end, you’ll thoroughly understand the model and training code and have the know-how to implement them from scratch.

While our tutorial concludes here, your journey in deep learning is far from over. In the upcoming tutorials, we’ll explore topics such as incorporating preprocessing and post-processing steps into the model, exporting the model to different formats for deployment, using the fine-tuned model to identify flawed training samples in our dataset, and building interactive in-browser demo projects similar to the one featured in this tutorial.

Once again, congratulations on your achievement, and keep learning!
