  • joydeepml2020

Distracted Driver Detection using CNN

Driver distraction is a very common phenomenon that causes accidents. Distracted driving plays a part in up to 25% of crashes.

As per the WHO report, the total number of deaths increases each year, and the most common cause is driver distraction. A study on distracted driving in India was also published recently. The major causes of driver distraction are cell phone usage, texting while driving, and the cognitive behaviours of drivers. In this study, we are developing an algorithm that can detect the following behaviours:

  1. texting - right

  2. talking on the phone - right

  3. texting - left

  4. talking on the phone - left

  5. operating the radio

  6. drinking

  7. reaching behind

  8. hair and makeup

  9. talking to co-passenger

Posing the Computer Vision Problem:

This problem can be posed as a supervised classification problem. We will use the State Farm Distracted Driver Detection dataset from Kaggle to train a computer vision algorithm that detects distracted driving behaviour. Using a CNN-based architecture, we will try to detect the inappropriate behaviour of drivers. The goal of the system is to predict the likelihood of each driving behaviour given an image of the driver.

About the dataset:

As discussed, I am going to use the State Farm Distracted Driver Detection dataset from Kaggle. The link is given below:

The dataset contains images of different driving behaviours classified into the following categories.

The 10 classes to predict are:

c0: safe driving

c1: texting - right

c2: talking on the phone - right

c3: texting - left

c4: talking on the phone - left

c5: operating the radio

c6: drinking

c7: reaching behind

c8: hair and makeup

c9: talking to passenger

The aim of the system is to predict the above-mentioned activities of the driver.

Creating a dataset:

To save on compute and storage, I am only going to use the training data. I have downloaded the dataset to my Google Drive and am fetching the training images from there.

My training dataset is in the folder "/content/drive/MyDrive/Driver_distraction/imgs/train". Now that we have the data, the next step is to explore the images and understand the dataset. We will visualise the dataset and each of its classes programmatically.

Project Outline:

The following steps will be performed :

  1. Loading the images of the training data

  2. Understand the distribution of each class

  3. Visualising the dataset

  4. Build the model

    • Create the baseline model

    • Experiment with transfer learning

    • Select the best model

    • Save the best performing model

    • Evaluate both models and their predictions for each class

Importing the libraries:

# Importing the necessary libraries
import os
import pathlib

import numpy as np
import tensorflow as tf

# Creating the data directory path (the training folder mentioned above)
data_dir = pathlib.Path("/content/drive/MyDrive/Driver_distraction/imgs/train")

# How many images are there in the training folder?
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)


There are 22434 images for training.

Exploring the folder structure: We are going to use the os.walk method to understand the folder structure.

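import os

# Walk the image folder and report how many sub-directories and files each level contains
for dirpath, dirfolder, filenames in os.walk("/content/drive/MyDrive/Driver_distraction/imgs"):
    print(f"There are {len(dirfolder)} directories and {len(filenames)} files in the path{dirpath}")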

There are 1 directories and 0 files in the path/content/drive/MyDrive/Driver_distraction/imgs
There are 10 directories and 0 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train
There are 0 directories and 2326 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c4
There are 0 directories and 2489 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c0
There are 0 directories and 2002 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c7
There are 0 directories and 2327 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c2
There are 0 directories and 2267 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c1
There are 0 directories and 2325 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c6
There are 0 directories and 2346 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c3
There are 0 directories and 2312 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c5
There are 0 directories and 1911 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c8
There are 0 directories and 2129 files in the path/content/drive/MyDrive/Driver_distraction/imgs/train/c9

The above output shows the structure of the training data. There are 10 folders (c0...c9), and now we will extract the number of files in each folder. This will help us understand whether the dataset is imbalanced or not.

def file_count(directory_path):
    """Traverse the directory and return two lists: the folder names and the number of files in each folder."""
    class_label_list = []
    num_files = []
    for dirpath, dir_folder, filenames in os.walk(directory_path):
        # The last path component is the folder (class) name
        dirpath = dirpath.split('/')
        class_label_list.append(dirpath[-1])
        num_files.append(len(filenames))
    return class_label_list, num_files

Output: (['', 'c4', 'c0', 'c7', 'c2', 'c1', 'c6', 'c3', 'c5', 'c8', 'c9'], [0, 2326, 2489, 2002, 2327, 2267, 2325, 2346, 2312, 1911, 2129])

# As os.walk traverses from the root and we are not counting the parent folder, we will use the lists from index one
class_label = class_label[1:]
number_file = number_file[1:]

We are skipping a few steps in this blog to keep the post short. For every step, please refer to my GitHub.

After plotting the class labels against the number of files, we got the following bar chart:

We can easily see that there are 10 different folders (c0, c1, c2, etc.), each representing a particular class of images.
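As a quick numeric companion to the chart, the per-class counts printed earlier can be compared directly. This is only a small sketch: the `class_counts` dictionary and the `imbalance_ratio` helper are names introduced here for illustration, and the counts are copied from the os.walk output above.

```python
# Per-class image counts, copied from the os.walk output shown earlier.
class_counts = {
    "c0": 2489, "c1": 2267, "c2": 2327, "c3": 2346, "c4": 2326,
    "c5": 2312, "c6": 2325, "c7": 2002, "c8": 1911, "c9": 2129,
}


def imbalance_ratio(counts):
    """Return the size of the largest class divided by the size of the smallest."""
    return max(counts.values()) / min(counts.values())


total = sum(class_counts.values())
for name, count in sorted(class_counts.items()):
    # Share of the whole training set taken by this class
    print(f"{name}: {count} images ({count / total:.1%})")

print(f"Total images: {total}")
print(f"Largest/smallest class ratio: {imbalance_ratio(class_counts):.2f}")
```

The largest class (c0) is roughly 1.3 times the size of the smallest (c8), so the imbalance is mild.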

Visualising the dataset for training

We will create a function that takes the path of the target folder and a class name and displays a picture. It will choose an image at random and display it along with its shape.

# Collect the sub-folder names of the training directory as class names
class_names = np.array([sorted([items.name for items in data_dir.glob("*")])])
class_names

array([['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']], dtype='<U2')

import os
import random

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

def visualize_training_image(target_dir, target_class):
    """Display a random image from the given class folder and return it."""
    # Set up the target directory from which we will load the image
    target_folder = target_dir + target_class
    # Get a random image path
    random_images = random.sample(os.listdir(target_folder), 1)
    # Read in the image and plot it using matplotlib
    img = mpimg.imread(target_folder + "/" + random_images[0])
    plt.imshow(img)
    plt.title(target_class)
    plt.axis("off")
    print(f"The shape of the image is {img.shape}")
    return img

# View a random image from the training dataset
img = visualize_training_image(target_dir="/content/drive/MyDrive/Driver_distraction/imgs/train/", target_class="c0")

We can use the above function to visualise images from each class randomly.

Finding out the class names and also printing out the preprocessed images:

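import matplotlib.pyplot as plt

# train_ds is the training dataset loaded from the image directory (see the next section)
plt.figure(figsize=(9, 9))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        # Show the i-th image of the first batch
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.axis("off")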

Creating the model with CNN and max pooling layers

In this section, we will first load the images from the directory and then perform the following steps:

  1. Resize the image to 180 X 180 (image height X image width)

  2. Create an 80:20 split, i.e. 80% training data and 20% test data

  3. Create batches of data for training.

We can load the image data from directory using the tf.keras.utils.image_dataset_from_directory utility.
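The three steps above could be sketched with that utility roughly as follows. This is a minimal sketch rather than the exact code from the repository: the helper name `load_train_val_datasets`, the batch size of 32, and the seed of 123 are assumptions made here, while the 180 x 180 image size and the 80:20 split come from the steps listed above.

```python
def load_train_val_datasets(train_dir, img_height=180, img_width=180,
                            batch_size=32, seed=123):
    """Build training and validation datasets with an 80:20 split."""
    # Imported inside the function so the helper can be defined before TensorFlow is loaded
    import tensorflow as tf

    common = dict(
        validation_split=0.2,                # hold out 20% of the images
        seed=seed,                           # same seed in both calls so the subsets do not overlap
        image_size=(img_height, img_width),  # resize every image on load
        batch_size=batch_size,               # assumed value, not taken from the post
    )
    train_ds = tf.keras.utils.image_dataset_from_directory(
        train_dir, subset="training", **common)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        train_dir, subset="validation", **common)
    return train_ds, val_ds


# Example usage with the folder used in this post:
# train_ds, val_ds = load_train_val_datasets(
#     "/content/drive/MyDrive/Driver_distraction/imgs/train")
# print(train_ds.class_names)
```

The utility infers the labels from the sub-folder names, so the c0...c9 folder layout shown earlier is all it needs.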

A very high-level architecture of our model looks like this:

The step-by-step code is available on my GitHub. After training the CNN model for 10 epochs, we got a training accuracy of 0.9982 and a validation accuracy of 0.9926. Below is the model summary of the trained model. We used the Sequential API of TensorFlow 2.0+.
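For readers who want a starting point, a small Sequential CNN with convolution and max pooling layers might look like the sketch below. The layer sizes, the optimizer, and the helper name `build_baseline_cnn` are assumptions chosen for illustration and are not necessarily the architecture in the repository. It also assumes a TensorFlow version that provides `tf.keras.layers.Rescaling`.

```python
def build_baseline_cnn(num_classes=10, img_height=180, img_width=180):
    """Return a small compiled Sequential CNN with Conv2D and MaxPooling2D blocks."""
    # Imported inside the function so the helper can be defined before TensorFlow is loaded
    import tensorflow as tf
    from tensorflow.keras import layers

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(img_height, img_width, 3)),
        layers.Rescaling(1.0 / 255),          # scale pixel values to [0, 1]
        layers.Conv2D(16, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(num_classes),            # raw logits, one per class
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


# Example usage (train_ds and val_ds are the datasets loaded from the image directory):
# model = build_baseline_cnn()
# model.summary()
# history = model.fit(train_ds, validation_data=val_ds, epochs=10)
```

The final layer outputs logits, which is why the loss is created with `from_logits=True`.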

Evaluating the model:

There are a number of ways to evaluate the model. We used accuracy during training, but as this is a multiclass classification problem, we can also use other classification metrics such as precision, recall, and F1 score. For the time being, we are going to plot the training curves.

Our loss curves are trending downward, and the accuracy for both training and validation is going up. Visualizing both sets of curves in one graph is a bit challenging. In the next section, we are going to split the curves.

We are splitting the loss and accuracy curves into two separate plots. The helper function below is going to help us with that.
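To make those metrics concrete, here is a minimal pure-Python sketch of per-class precision, recall, and F1. The helper name `per_class_metrics` and the toy labels are made up for illustration and are not results from this project.

```python
def per_class_metrics(y_true, y_pred, num_classes):
    """Return a list of (precision, recall, f1) tuples, one per class index."""
    metrics = []
    for c in range(num_classes):
        # Count true positives, false positives, and false negatives for class c
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        metrics.append((precision, recall, f1))
    return metrics


# Toy labels (made up for illustration, not results from this project)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
for c, (p, r, f1) in enumerate(per_class_metrics(y_true, y_pred, num_classes=3)):
    print(f"class {c}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
```

In practice, a library routine such as scikit-learn's `classification_report` produces the same per-class numbers with less code.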

def plot_loss_curve(history):
    """Plot two separate figures: loss curves and accuracy curves for training and validation."""
    training_loss = history.history["loss"]
    validation_loss = history.history["val_loss"]
    training_accuracy = history.history["accuracy"]
    validation_accuracy = history.history["val_accuracy"]
    epochs = range(len(history.history["loss"]))  # How many epochs the model was trained for

    # Plot loss
    plt.plot(epochs, training_loss, label="training_loss")
    plt.plot(epochs, validation_loss, label="val_loss")
    plt.title("Loss")
    plt.legend()

    # Plot accuracy
    plt.figure()
    plt.plot(epochs, training_accuracy, label="training_accuracy")
    plt.plot(epochs, validation_accuracy, label="val_accuracy")
    plt.title("Accuracy")
    plt.legend()

Using this helper function, we plotted these two graphs.

We also plotted the confusion matrix after evaluating the predicted values on the validation dataset. Below is the confusion matrix:

We are using our helper function to plot the confusion matrix.

From the confusion matrix, we can easily identify where our model is going wrong.

This project can be improved further by:

  • Using test data, making predictions on real custom images, and analysing the F1 score of each class.

  • Using transfer learning

  • Identifying potential data leakage in the dataset.
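The idea behind reading a confusion matrix can be shown with a small pure-Python sketch. This is not the plotting helper mentioned above. The function names `confusion_matrix` and `most_confused_pairs` and the toy labels are introduced here for illustration only.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Return a num_classes x num_classes matrix; rows are true labels, columns are predictions."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def most_confused_pairs(matrix, class_names, top_k=3):
    """Return the top_k largest off-diagonal cells as (true_name, predicted_name, count)."""
    pairs = [
        (class_names[i], class_names[j], matrix[i][j])
        for i in range(len(matrix))
        for j in range(len(matrix))
        if i != j and matrix[i][j] > 0
    ]
    return sorted(pairs, key=lambda item: item[2], reverse=True)[:top_k]


# Toy labels (made up for illustration, not results from this project)
names = ["c0", "c1", "c2"]
y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 2, 2, 1, 1, 2, 0, 2]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)
print(most_confused_pairs(cm, names))
```

The off-diagonal cells are the mistakes, so sorting them shows which pairs of classes the model mixes up most often.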

  • Using custom images to make prediction on the
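One way data leakage can happen is when a random split places images of the same driver in both the training and validation sets. The sketch below shows a group-aware split under the assumption that each image can be mapped to a driver identifier. The `image_to_driver` mapping, the ids in it, and the helper name `split_by_group` are hypothetical, and this is not part of the pipeline described in this post.

```python
import random


def split_by_group(image_to_group, val_fraction=0.2, seed=42):
    """Split image names so that no group appears in both training and validation."""
    groups = sorted(set(image_to_group.values()))
    rng = random.Random(seed)
    rng.shuffle(groups)
    # Hold out whole groups rather than individual images
    n_val = max(1, round(len(groups) * val_fraction))
    val_groups = set(groups[:n_val])
    train = [img for img, g in image_to_group.items() if g not in val_groups]
    val = [img for img, g in image_to_group.items() if g in val_groups]
    return train, val


# Hypothetical image-to-driver mapping, for illustration only
image_to_driver = {
    "img_1.jpg": "d1", "img_2.jpg": "d1", "img_3.jpg": "d2",
    "img_4.jpg": "d2", "img_5.jpg": "d3", "img_6.jpg": "d4",
    "img_7.jpg": "d5", "img_8.jpg": "d5",
}
train_imgs, val_imgs = split_by_group(image_to_driver)
train_drivers = {image_to_driver[i] for i in train_imgs}
val_drivers = {image_to_driver[i] for i in val_imgs}
print("Drivers shared between splits:", train_drivers & val_drivers)
```

If validation accuracy drops noticeably under such a split, that would suggest the random split was leaking information.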
