% PyTorch Neural Network: a Look at Life Expectancy with 95.86% Prediction Accuracy
➥ INTRODUCTION:
This notebook is a journey with PyTorch's neural network tools and a custom linear regression model class. The objective, or destination, of this journey is to predict life expectancy as accurately as possible using the thorough World Health Organization dataset, which covers 183 countries over the years 2000 to 2015.
➢ Dataset Information:
Dataset Source
Dataset Description from Source:
CONTEXT: Although there have been a lot of studies undertaken in the past on factors affecting life expectancy considering demographic variables, income composition and mortality rates, it was found that the effect of immunization and the human development index was not taken into account in the past. Also, some of the past research was done considering multiple linear regression based on a data set of one year for all the countries. Hence, this gives motivation to resolve both the factors stated previously by formulating a regression model based on a mixed effects model and multiple linear regression while considering data from a period of 2000 to 2015 for all the countries. Important immunizations like Hepatitis B, Polio and Diphtheria will also be considered. In a nutshell, this study will focus on immunization factors, mortality factors, economic factors, social factors and other health related factors as well. Since the observations in this dataset are based on different countries, it will be easier for a country to determine the predicting factor which is contributing to a lower value of life expectancy. This will help in suggesting to a country which area should be given importance in order to efficiently improve the life expectancy of its population.
CONTENT: The project relies on accuracy of data. The Global Health Observatory (GHO) data repository under the World Health Organization (WHO) keeps track of the health status as well as many other related factors for all countries. The data-sets are made available to the public for the purpose of health data analysis. The data-set related to life expectancy and health factors for 193 countries has been collected from the same WHO data repository website and its corresponding economic data was collected from the United Nations website. Among all categories of health-related factors, only those critical factors were chosen which are more representative. It has been observed that in the past 15 years, there has been a huge development in the health sector resulting in improvement of human mortality rates, especially in the developing nations in comparison to the past 30 years. Therefore, in this project we have considered data from the years 2000-2015 for 193 countries for further analysis. The individual data files have been merged together into a single data-set. Initial visual inspection of the data showed some missing values. As the data-sets were from WHO, we found no evident errors. Missing data was handled in R software by using the Missmap command. The result indicated that most of the missing data was for population, Hepatitis B and GDP. The missing data were from less known countries like Vanuatu, Tonga, Togo, Cabo Verde etc. Finding all data for these countries was difficult and hence, it was decided that we exclude these countries from the final model data-set. The final merged file (final dataset) consists of 22 columns and 2938 rows, which meant 20 predicting variables. All predicting variables were then divided into several broad categories: Immunization related factors, Mortality factors, Economical factors and Social factors.
⇾ Cleaning Column Names:
The data came in a little messy, including column names that just won't do. So let's clean those up immediately.
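A minimal sketch of this kind of cleanup, assuming the raw CSV has been loaded into a pandas DataFrame named `data_raw` (the name used in the next heading). The normalization rule here (trim whitespace, lowercase, turn runs of spaces, hyphens, or slashes into one underscore) is an illustrative choice, not necessarily the notebook's exact transformation, and the sample headers are made up for the demo.

```python
import re

import pandas as pd


def clean_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with normalized column names.

    Assumed rule: strip surrounding whitespace, lowercase, and collapse any
    run of spaces, hyphens, or slashes into a single underscore.
    """
    out = df.copy()
    out.columns = [
        re.sub(r"[\s\-/]+", "_", str(col).strip().lower())
        for col in df.columns
    ]
    return out


# Hypothetical messy headers, for illustration only.
data_raw = pd.DataFrame(columns=["Country", "Life expectancy ", " BMI ", "under-five deaths "])
data_raw = clean_column_names(data_raw)
print(list(data_raw.columns))
# ['country', 'life_expectancy', 'bmi', 'under_five_deaths']
```

Returning a copy rather than mutating in place keeps the raw frame available if you want to compare before and after.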
⇾ data_raw.describe()
This gives an overview of the data we are working with and how each column is distributed.
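To show what that call reports, here is a small sketch on a made-up frame standing in for the cleaned `data_raw`. `describe()` summarizes each numeric column (count, mean, std, min, quartiles, max). Its `count` row skips missing values, so comparing counts across columns is a quick way to spot the NULL values handled in the next step. The `isna().sum()` line is an extra check added for illustration.

```python
import pandas as pd

# Hypothetical toy frame standing in for the cleaned data_raw.
data_raw = pd.DataFrame({
    "life_expectancy": [65.0, 59.9, 74.6, None],
    "bmi": [19.1, 18.6, 58.0, 25.2],
})

summary = data_raw.describe()   # count, mean, std, min, 25%, 50%, 75%, max per numeric column
print(summary)

missing = data_raw.isna().sum()  # number of missing values per column
print(missing)
```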
⇾ Separating the Data:
Here, I am accounting for the different types of data by separating the numerical columns from the categorical ones. In this dataset, only one categorical column needs encoding: "status", which is either developed or developing. The "country" column will be useful for plotting and visualizing the data, but it will not be part of the model's inputs, so we do not need to encode it.
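One way to do this separation, sketched under a few assumptions: the column names follow the cleaned lowercase scheme, the rows and country names are placeholders, and "status" is encoded as a single 0/1 column (1 = Developed). With only two categories, this carries the same information as a one-hot encoding from `pd.get_dummies`; the notebook's own encoding may differ.

```python
import pandas as pd

# Hypothetical rows with placeholder country names and made-up values.
df = pd.DataFrame({
    "country": ["Country A", "Country B", "Country C"],
    "year": [2015, 2015, 2015],
    "status": ["Developing", "Developing", "Developed"],
    "life_expectancy": [65.0, 71.2, 81.0],
    "bmi": [19.1, 45.3, 61.8],
})

# Numerical columns can go to the model as-is (after imputation).
numeric_cols = df.select_dtypes(include="number").columns.tolist()

# "status" is the only categorical input: encode 1 = Developed, 0 = Developing.
# "country" is kept only for plotting and is left unencoded.
df["status"] = (df["status"] == "Developed").astype(int)
categorical_cols = ["status"]

print(numeric_cols, categorical_cols)
```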
⇾ Removing NULL Values:
Because there are quite a few NULL values, I will replace them with the mean of their column. This keeps every row usable for training; if NaN values were left in, they would propagate through the network's computations and turn the loss into NaN.
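A minimal sketch of column-wise mean imputation with pandas, on made-up values. `DataFrame.mean()` skips NaN by default, so each gap is filled with the average of the values that are present in its own column. One design note: computing the means on the full dataset before the train/validation split lets a little validation information leak into training; computing them on the training rows only is the stricter option.

```python
import numpy as np
import pandas as pd

# Hypothetical numeric columns with gaps.
numeric = pd.DataFrame({
    "gdp": [584.3, np.nan, 4132.8, 612.7],
    "hepatitis_b": [65.0, 62.0, np.nan, 67.0],
})

# Replace each NaN with the mean of its own column (NaNs are ignored when computing the mean).
numeric = numeric.fillna(numeric.mean())
print(numeric)
```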
➢ Correlation:
Here is our first look at how the variables correlate in the raw data. The target for the model will be the life expectancy column. Some interesting patterns stand out: quite a few features correlate directly with life expectancy, to varying degrees. We will look at those more closely below.
Just past the heatmap are individual plots of the most impactful features against life expectancy; some show the relationship more clearly than others.
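A sketch of how the "most impactful features" could be ranked, assuming a cleaned numeric DataFrame with a `life_expectancy` column; the feature names and values below are invented. `DataFrame.corr()` returns the Pearson correlation matrix that a heatmap visualizes, and sorting the target's column by absolute value puts strong negative relationships (such as mortality) alongside strong positive ones. Keep in mind that Pearson correlation only captures linear relationships.

```python
import pandas as pd

# Hypothetical cleaned data.
df = pd.DataFrame({
    "life_expectancy": [50.0, 60.0, 70.0, 80.0],
    "schooling":       [5.0, 9.0, 12.0, 16.0],
    "adult_mortality": [400.0, 250.0, 150.0, 60.0],
    "population":      [3.0, 1.0, 4.0, 2.0],
})

corr = df.corr()  # Pearson correlation matrix (the data behind a heatmap)

# Correlation of every feature with the target, strongest (by magnitude) first.
target_corr = (
    corr["life_expectancy"]
    .drop("life_expectancy")
    .sort_values(key=lambda s: s.abs(), ascending=False)
)
print(target_corr)
```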
⇾ Dividing datasets into training and validation sets:
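A minimal sketch of a random train/validation split with PyTorch's `TensorDataset` and `random_split`. The random tensors stand in for the prepared inputs (20 predicting variables) and the life expectancy target, and the 15% validation fraction is an assumption rather than the notebook's value.

```python
import torch
from torch.utils.data import TensorDataset, random_split

torch.manual_seed(42)  # make the random split reproducible

# Hypothetical stand-ins: 100 rows, 20 input features, 1 target column.
inputs = torch.randn(100, 20)
targets = torch.randn(100, 1)

dataset = TensorDataset(inputs, targets)

val_size = int(0.15 * len(dataset))   # assumed 15% validation fraction
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

print(len(train_ds), len(val_ds))
```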
⇾ Establishing batches for training:
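Batching is typically done with `DataLoader`. In this sketch, the datasets are random stand-ins for the training and validation sets from the previous step, and `batch_size = 16` is an assumed value. Training batches are shuffled every epoch so the model does not see rows in a fixed order; validation order does not affect the loss, so it is left unshuffled with a larger batch size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical training/validation sets (in the notebook these come from the previous step).
train_ds = TensorDataset(torch.randn(85, 20), torch.randn(85, 1))
val_ds = TensorDataset(torch.randn(15, 20), torch.randn(15, 1))

batch_size = 16  # assumed value

train_loader = DataLoader(train_ds, batch_size, shuffle=True)  # reshuffled every epoch
val_loader = DataLoader(val_ds, batch_size * 2)                 # order does not matter here

xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # torch.Size([16, 20]) torch.Size([16, 1])
```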
➢ Auditioning the best loss functions for this data:
⇾ Preliminary Model
⇾ Loss Function Auditioning Function:
I wrote the following function so that I can try out a variety of the loss functions available in PyTorch and see which works best for this dataset and model. I have narrowed them down to the following two, and for each one I will train over a list of learning rates to tune that hyperparameter. Spoiler alert! `nn.L1Loss` definitely wins, although `nn.MSELoss` was not bad.
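Here is a sketch of what such an audition function might look like. It is not the notebook's code: the data is synthetic, the preliminary model is a bare `nn.Linear`, training is full-batch SGD, and the name `audition` is hypothetical. Each (loss, learning rate) pair trains a fresh model. Validation is always scored with mean absolute error, so the candidates are compared on the same yardstick regardless of which loss drove training.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical synthetic regression data: 200 samples, 20 features.
X = torch.randn(200, 20)
true_w = torch.randn(20, 1)
y = X @ true_w + 0.1 * torch.randn(200, 1)
X_train, y_train, X_val, y_val = X[:160], y[:160], X[160:], y[160:]


def audition(loss_fn, lrs, epochs=200):
    """Train a fresh linear model per learning rate; return {lr: validation MAE}."""
    results = {}
    for lr in lrs:
        model = nn.Linear(X.shape[1], 1)
        opt = torch.optim.SGD(model.parameters(), lr=lr)
        for _ in range(epochs):
            opt.zero_grad()
            loss = loss_fn(model(X_train), y_train)
            loss.backward()
            opt.step()
        with torch.no_grad():
            # Common yardstick for every candidate: mean absolute error on validation data.
            results[lr] = nn.functional.l1_loss(model(X_val), y_val).item()
    return results


learning_rates = [1e-3, 1e-2, 1e-1]
scores = {name: audition(fn, learning_rates)
          for name, fn in [("L1Loss", nn.L1Loss()), ("MSELoss", nn.MSELoss())]}

for name, by_lr in scores.items():
    best_lr = min(by_lr, key=by_lr.get)
    print(f"{name}: best lr={best_lr}, val MAE={by_lr[best_lr]:.4f}")
```

On this Gaussian-noise toy data, the winner says nothing about the WHO dataset; the point is the structure of the search.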
⇾ nn.L1Loss(): Audition for the role of loss function
⇾ nn.MSELoss(): Audition for the role of loss function
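To make the difference between the two candidates concrete, here are both losses on the same made-up predictions, where one prediction is badly off. `nn.L1Loss` averages absolute errors, so the outlier adds linearly, while `nn.MSELoss` averages squared errors, so the outlier dominates. This is one plausible reason L1 did better here, given the heavy outliers noted later among some samples, but the notebook does not test that explanation directly.

```python
import torch
from torch import nn

# Made-up predictions vs. targets (years); the last prediction is badly off.
preds = torch.tensor([70.0, 65.0, 80.0])
targets = torch.tensor([72.0, 64.0, 60.0])

l1 = nn.L1Loss()(preds, targets)    # mean |error|  = (2 + 1 + 20) / 3
mse = nn.MSELoss()(preds, targets)  # mean error^2  = (4 + 1 + 400) / 3

print(l1.item(), mse.item())  # about 7.667 and 135.0
```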
➢ Defining Custom Class Model and Evaluation Functions:
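A sketch of what a custom model class with its evaluation helpers could look like. The class name `LifeExpectancyModel` and the helpers `evaluate` and `fit` are hypothetical; the notebook's own class may differ. The model is a single linear layer trained with L1 loss, the winner of the audition above. The short demo at the bottom runs on random data only to show the call pattern.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class LifeExpectancyModel(nn.Module):
    """Hypothetical linear-regression model class; names are illustrative."""

    def __init__(self, n_inputs: int, n_outputs: int = 1):
        super().__init__()
        self.linear = nn.Linear(n_inputs, n_outputs)

    def forward(self, xb):
        return self.linear(xb)

    def training_step(self, batch):
        inputs, targets = batch
        return F.l1_loss(self(inputs), targets)  # L1 loss, as chosen in the audition

    def validation_step(self, batch):
        inputs, targets = batch
        return F.l1_loss(self(inputs), targets).detach()


def evaluate(model, loader):
    """Average validation loss over all batches."""
    model.eval()
    with torch.no_grad():
        losses = [model.validation_step(batch) for batch in loader]
    return torch.stack(losses).mean().item()


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Simple training loop; returns the validation loss after each epoch."""
    optimizer = opt_func(model.parameters(), lr)
    history = []
    for _ in range(epochs):
        model.train()
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        history.append(evaluate(model, val_loader))
    return history


# Demo on random data (not the WHO dataset).
torch.manual_seed(0)
X = torch.randn(64, 20)
y = X.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(X, y), batch_size=16)
model = LifeExpectancyModel(20)
history = fit(5, 1e-2, model, loader, loader)
print(history)
```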
➢ Evaluate Predictions:
Here, I am using single samples from the evaluation data that I separated out from the raw data at the beginning of the notebook. Let's have a look at the data we are working with. It looks like the random sample gave us a good selection of data for testing and evaluating the model.
The following sets up the evaluation data so that we can see how well the model performs on real-world data it has never seen before.
Here we see that we have 138 samples to work with for evaluation and testing. The function below compares the trained model's prediction for each sample with its target and prints the target value, the prediction, the difference between the two, and the model's percentage accuracy for that sample. At the end of the run, it averages the per-sample accuracies into a final percentage.
⇾ Determining Accuracy: (function)
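A sketch of a per-sample accuracy function along those lines. The text does not give the accuracy formula, so this assumes accuracy = 100 minus the absolute percentage error, which matches the idea of "percentage accuracy" but may not be the notebook's exact definition. `sample_accuracy` and `evaluate_samples` are hypothetical names. Each sample is given a batch dimension with `unsqueeze(0)` before it goes through the model.

```python
import torch


def sample_accuracy(prediction: float, target: float) -> float:
    """Percentage accuracy for one sample (assumed: 100 minus absolute percent error)."""
    return 100.0 - abs(prediction - target) / abs(target) * 100.0


def evaluate_samples(model, inputs, targets):
    """Print target, prediction, difference, and accuracy per sample; return mean accuracy."""
    model.eval()
    accuracies = []
    with torch.no_grad():
        for x, t in zip(inputs, targets):
            pred = model(x.unsqueeze(0)).item()  # single-sample prediction
            target = t.item()
            acc = sample_accuracy(pred, target)
            accuracies.append(acc)
            print(f"target={target:.2f} pred={pred:.2f} "
                  f"diff={pred - target:+.2f} accuracy={acc:.2f}%")
    mean_acc = sum(accuracies) / len(accuracies)
    print(f"average accuracy: {mean_acc:.2f}%")
    return mean_acc
```

For example, `sample_accuracy(76.0, 80.0)` is 95.0 because the prediction is 5% off the target.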
➢ Accuracy of 95.86%
So we can see here that the model did very well at predicting life expectancy from the various features in the dataset. We will now look at the samples that tested below the 95.86% average and see what it is about those samples that threw off our model. Perhaps this will shed some light on what might lead to greater accuracy in the future.
➢ Review evaluation samples that score below average accuracy:
Not only did we get a strong overall accuracy score of 95.86%, but even the test samples that came in below average have a high level of accuracy, as you can see from the plot above. The vast majority are still in the 90s!
⇾ 10 Examples from the Below Average Samples:
Let's look at the samples that came in below the accuracy average.
⇾ Overview of the Below Average Samples:
As you can see below, the average accuracy score for the 43 samples that came in below our average is still 90%. It looks like we have a pretty good model here!
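A pandas sketch of how the below-average subset can be pulled out, summarized, and sorted worst-first. It assumes a per-sample results table with hypothetical columns `country`, `status`, `target`, and `prediction`, and it reuses the assumed accuracy definition (100 minus absolute percent error); all values are invented. The `status` count is a quick way to check whether the weaker predictions cluster in developing countries.

```python
import pandas as pd

# Hypothetical per-sample evaluation results.
results = pd.DataFrame({
    "country": ["A", "B", "C", "D", "E"],
    "status": ["Developing", "Developed", "Developing", "Developing", "Developed"],
    "target": [60.0, 81.0, 52.0, 74.0, 79.0],
    "prediction": [66.0, 80.2, 47.3, 73.5, 78.6],
})
# Assumed accuracy definition: 100 minus absolute percent error.
results["accuracy"] = 100 - (results["prediction"] - results["target"]).abs() / results["target"] * 100

overall_avg = results["accuracy"].mean()
below = results[results["accuracy"] < overall_avg]

print(f"{len(below)} samples below the {overall_avg:.2f}% average")
print(f"their mean accuracy: {below['accuracy'].mean():.2f}%")
print(below["status"].value_counts())
print(below.sort_values("accuracy").head(10))  # up to 10 samples, worst first
```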
⇾ Below Average Samples by Accuracy Ascending:
The graphs plotted below show the most notable features and how they correspond to the model's accuracy in predicting life expectancy for samples that scored below our average accuracy of 95.86%. As you can see from the excerpt of below-average samples above, most of the countries the model had a slightly tougher time predicting are developing countries whose feature values include heavy outliers.
📊 Plotting Below Average Accuracy Samples:
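A minimal matplotlib sketch of this kind of plot: one scatter panel per feature, with the sample's accuracy on a shared y-axis. The column names, the choice of "notable" features, and the values are all assumptions for illustration. The `Agg` backend line only exists so the script runs without a display; drop it inside Jupyter.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; remove this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical below-average samples.
below = pd.DataFrame({
    "accuracy": [88.1, 90.0, 91.0, 93.4],
    "adult_mortality": [420.0, 310.0, 280.0, 150.0],
    "hiv_aids": [12.5, 6.1, 3.2, 0.4],
    "schooling": [6.0, 8.5, 9.1, 11.7],
})

features = ["adult_mortality", "hiv_aids", "schooling"]  # assumed "notable" features
fig, axes = plt.subplots(1, len(features), figsize=(4 * len(features), 3.5), sharey=True)
for ax, feature in zip(axes, features):
    ax.scatter(below[feature], below["accuracy"])  # one point per below-average sample
    ax.set_xlabel(feature)
axes[0].set_ylabel("accuracy (%)")
fig.tight_layout()
```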
➢ Conclusion:
I honestly feel like I have only scratched the surface of what can be taken away from this dataset, but it was a very good start and a very successful implementation of a PyTorch linear regression neural network. I hope you enjoyed this journey with me!