Common Machine Learning Obstacles
In this blog, Seth DeLand of MathWorks discusses two of the most common machine learning obstacles: choosing the right classification model and eliminating data overfitting.
By Seth DeLand, Product Marketing Manager, Data Analytics, MathWorks
Engineers and scientists who are modeling with machine learning often face challenges when working with data. Two of the most common obstacles relate to choosing the right classification model and eliminating data overfitting.
Classification models assign items to a discrete group or class based on a specific set of features. Determining the best classification model is often difficult because each dataset and desired outcome is unique. Overfitting occurs when a model is too closely aligned with limited training data that may contain noise or errors. An overfit model does not generalize well to data outside the training set, which limits its usefulness in a production system.
By integrating scalable software tools and machine learning techniques, engineers and scientists can identify the best model and protect against overfitting.
Choosing the Classification Model
Choosing a classification model can be challenging because each model type has its own characteristics, which can be strengths or weaknesses depending on the problem.
For starters, you must answer a few questions about the type and purpose of the data:
- What is the model meant to accomplish?
- How much data is there, and what type of data is it?
- How much detail is needed? Is storage a limiting factor?
Answering these questions can help narrow the choices and select the correct classification model. Engineers and scientists can then use cross-validation to estimate how accurately each candidate model will predict on data it was not trained on. After cross-validation, you can select the classification model that performs best.
There are many types of classification models; here are five common ones:
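To make the cross-validation step concrete, here is a minimal Python sketch (not MATLAB, and not the workflow of any particular tool) that scores two candidate classifiers by 5-fold cross-validation accuracy and keeps the better one. The toy data generator, the two candidate models (a majority-class baseline and a nearest-centroid classifier), and all function names are hypothetical, chosen only for illustration.

```python
import random
from collections import Counter

def make_toy_data(n_per_class=30, seed=0):
    """Hypothetical 2-D data: class "a" centered at (0, 0), class "b" at (3, 3)."""
    rng = random.Random(seed)
    data = []
    for label, center in (("a", 0.0), ("b", 3.0)):
        for _ in range(n_per_class):
            data.append(((rng.gauss(center, 1.0), rng.gauss(center, 1.0)), label))
    rng.shuffle(data)  # shuffle so the interleaved folds below are random
    return data

def fit_majority(train):
    """Baseline model: always predict the most common training label."""
    label = Counter(y for _, y in train).most_common(1)[0][0]
    return lambda x: label

def fit_nearest_centroid(train):
    """Predict the class whose mean feature vector is closest (squared Euclidean) to x."""
    grouped = {}
    for x, y in train:
        grouped.setdefault(y, []).append(x)
    centroids = {y: [sum(col) / len(col) for col in zip(*xs)] for y, xs in grouped.items()}

    def predict(x):
        return min(centroids, key=lambda y: sum((a - b) ** 2 for a, b in zip(x, centroids[y])))

    return predict

def cv_accuracy(fit, data, k=5):
    """Mean accuracy over k folds; each fold is held out exactly once."""
    folds = [data[i::k] for i in range(k)]
    scores = []
    for i, held_out in enumerate(folds):
        train = [row for j, fold in enumerate(folds) if j != i for row in fold]
        model = fit(train)  # train only on the other k - 1 folds
        scores.append(sum(model(x) == y for x, y in held_out) / len(held_out))
    return sum(scores) / k

candidates = {"majority": fit_majority, "nearest_centroid": fit_nearest_centroid}
data = make_toy_data()
results = {name: cv_accuracy(fit, data) for name, fit in candidates.items()}
best = max(results, key=results.get)
print(results, "->", best)
```

Every candidate is trained and scored on the same folds, so the comparison is like-for-like; adding more candidates only means adding entries to the `candidates` dictionary.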
- Logistic Regression: This model is often used as a baseline due to its simplicity. It is used for problems with two possible classes into which data may be categorized. A logistic regression model returns the probability that a data point belongs to each class.
- k-nearest neighbor (kNN): This simple yet effective classification method categorizes data points based on their distance to other points in a training dataset. The training time of kNN is short, but this model can mistake irrelevant attributes for important ones unless weights are applied to the data, especially as the number of data points grows.
- Decision Trees: These models predict responses through a sequence of decisions that is easy to visualize, and it's relatively easy to follow the decision path taken from root to leaf. This type of model is especially useful when it's important to show how the conclusion was reached.
- Support Vector Machine (SVM): This model uses a hyperplane to separate data into two or more classes. It is accurate, tends not to overfit, and is relatively easy to interpret, but training time can be on the longer side, especially for larger datasets.
- Artificial neural networks (ANNs): These networks can be configured and trained to solve a variety of problems, including classification and time series prediction. However, the trained models are known to be difficult to interpret.
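The kNN description above can be sketched in a few lines of Python. This is a brute-force illustration under stated assumptions: Euclidean distance, majority voting among the k nearest points, and an optional per-feature weight vector as one possible way of applying weights so that an irrelevant, large-scale feature does not dominate the distance. The function names and toy data are hypothetical.

```python
import math
from collections import Counter

def weighted_distance(a, b, weights):
    """Euclidean distance in which each feature's squared difference is scaled by a weight."""
    return math.sqrt(sum(w * (ai - bi) ** 2 for w, ai, bi in zip(weights, a, b)))

def knn_predict(train, query, k=3, weights=None):
    """Label query by majority vote among its k nearest training points (brute-force search)."""
    if weights is None:
        weights = [1.0] * len(query)  # equal weights: plain Euclidean distance
    nearest = sorted(train, key=lambda row: weighted_distance(row[0], query, weights))[:k]
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

# Hypothetical toy data: feature 0 separates the classes; feature 1 is large-scale noise.
train = [
    ((0.0, 10.0), "red"),
    ((0.2, 20.0), "red"),
    ((0.1, 30.0), "red"),
    ((5.0, 0.0), "blue"),
    ((5.2, 1.0), "blue"),
    ((4.9, 2.0), "blue"),
]
query = (0.1, 1.5)

print(knn_predict(train, query))                      # blue
print(knn_predict(train, query, weights=(1.0, 0.0)))  # red
```

With equal weights, the noisy second feature dominates the distance and the query is labeled "blue"; giving that feature zero weight lets the informative first feature decide, and the query becomes "red". Because "training" is just storing the data, each prediction in this brute-force version scans the whole training set.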
Engineers and scientists can simplify the decision-making process by using scalable software tools to determine which model best fits a set of features, assess classifier performance, compare and improve model accuracy, and finally, export the best model. These tools also help users explore the data, select features, specify validation schemes, and train multiple models.
Eliminating Data Overfitting
Overfitting occurs when a model fits a particular dataset but does not generalize well to new data. Overfitting is often hard to avoid because it typically results from insufficient training data, especially when the person responsible for the model did not gather the data. The best way to avoid overfitting is to use enough training data to accurately reflect the diversity and complexity of the data the model must handle.
Regularization and generalization are two additional methods engineers and scientists can apply to detect and reduce overfitting. Regularization is a technique that prevents the model from over-relying on individual data points. Regularization algorithms introduce additional information into the model and handle multicollinearity and redundant predictors by making the model more parsimonious and accurate. These algorithms typically work by applying a penalty for complexity, such as adding a term based on the size of the model's coefficients to the quantity being minimized, or including a roughness penalty.
Generalization divides the available data into three subsets. The first is the training set, and the second is the validation set. The error on the validation set is monitored during training, and the model is tuned until its validation performance is acceptable. The third subset is the test set, which is applied to the fully trained classifier after the training and cross-validation phases to confirm that the model has not overfit the training and validation data.
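As one concrete case of a coefficient penalty, ridge (L2) regularization adds lam * w^2 to the squared-error objective. The Python sketch below assumes the simplest possible model, a single feature with no intercept, where minimizing sum((y_i - w*x_i)^2) + lam * w^2 has the closed-form solution w = sum(x_i*y_i) / (sum(x_i^2) + lam). The data points are made up for illustration.

```python
def ridge_slope(xs, ys, lam):
    """Slope w minimizing sum((y - w*x)**2) + lam * w**2 for a one-feature, no-intercept model.

    Setting the derivative to zero gives w = sum(x*y) / (sum(x*x) + lam);
    lam = 0 reduces to ordinary least squares.
    """
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + lam)

# Hypothetical noisy points lying roughly on y = 2x.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]
for lam in (0.0, 1.0, 10.0, 100.0):
    print(f"lambda = {lam:>5}: slope = {ridge_slope(xs, ys, lam):.3f}")
```

Increasing `lam` shrinks the fitted slope toward zero: the penalty trades a little bias for less sensitivity to noise in individual points, which is the "parsimony" the paragraph describes.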
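The three-subset scheme can be sketched as follows, assuming a hypothetical one-dimensional dataset with about 10% label noise and a kNN classifier whose only tuned hyperparameter is k. The training set supplies the reference points, the validation set picks k, and the test set is scored exactly once at the end. The 60/20/20 proportions and all names are illustrative assumptions, not recommendations from the original article.

```python
import random

def three_way_split(rows, train_frac=0.6, val_frac=0.2, seed=0):
    """Shuffle a copy of rows and cut it into training, validation, and test subsets."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    n_train = round(len(rows) * train_frac)
    n_val = round(len(rows) * val_frac)
    return rows[:n_train], rows[n_train:n_train + n_val], rows[n_train + n_val:]

def knn_predict(train, x, k):
    """Majority vote over the k training points nearest to x (odd k avoids ties)."""
    nearest = sorted(train, key=lambda row: abs(row[0] - x))[:k]
    return sum(label for _, label in nearest) * 2 > k

def accuracy(train, rows, k):
    return sum(knn_predict(train, x, k) == y for x, y in rows) / len(rows)

# Hypothetical data: label is True when x > 5, with roughly 10% of labels flipped as noise.
rng = random.Random(1)
data = []
for _ in range(200):
    x = rng.uniform(0, 10)
    label = (x > 5) != (rng.random() < 0.1)  # XOR flips the label about 10% of the time
    data.append((x, label))

train, val, test = three_way_split(data)

# Tune k on the validation set only; the test set plays no part in this choice.
candidate_ks = (1, 3, 5, 7, 9, 15)
best_k = max(candidate_ks, key=lambda k: accuracy(train, val, k))

# Score the test set once, after tuning is finished.
test_accuracy = accuracy(train, test, best_k)
print(f"chosen k = {best_k}, test accuracy = {test_accuracy:.2f}")
```

Keeping the test set out of the tuning loop is the point of the design: if k were chosen by test accuracy, the test score would no longer be an honest estimate of performance on new data.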
Six common cross-validation methods can help detect and prevent overfitting:
- k-fold: Partitions data into k randomly chosen subsets (or folds) of roughly equal size; one fold is used to validate the model trained on the remaining folds. This process is repeated k times so that each fold is used exactly once for validation.
- Holdout: Separates data into two subsets of specified ratio for training and validation.
- Leave-one-out: Partitions data using the k-fold approach, where k equals the total number of observations in the data.
- Repeated random subsampling: Performs Monte Carlo repetitions of randomly separating data and aggregates results over all the runs.
- Stratify: Partitions data so both training and test sets have roughly the same class proportions in the response or target.
- Resubstitution: Uses the training data for validation without separating it. This method often produces overly optimistic performance estimates and should be avoided when sufficient data is available.
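Most of these partitioning schemes reduce to a few lines of index bookkeeping. The Python sketch below generates (training indices, validation indices) pairs for k-fold, leave-one-out, holdout, repeated random subsampling, and stratified k-fold; the function names are hypothetical, and this is not the API of MATLAB or any other library. Resubstitution needs no code: it would simply use every index for both training and validation, which is exactly why its estimates are optimistic.

```python
import random
from collections import defaultdict

def _pairs_from_folds(folds):
    """Yield (train_idx, val_idx): each fold validates once; the rest form the training set."""
    for i, val in enumerate(folds):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, val

def kfold_splits(n, k, seed=0):
    """k-fold: shuffle indices, then deal them into k folds of roughly equal size."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return _pairs_from_folds([idx[i::k] for i in range(k)])

def leave_one_out_splits(n):
    """Leave-one-out: k-fold with k equal to the number of observations."""
    return kfold_splits(n, n)

def holdout_split(n, val_frac=0.3, seed=0):
    """Holdout: a single shuffled split with the given validation fraction."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_val = round(n * val_frac)
    return idx[n_val:], idx[:n_val]

def repeated_random_subsampling(n, repeats=10, val_frac=0.3, seed=0):
    """Monte Carlo repetitions of independent holdout splits; scores are averaged over runs."""
    rng = random.Random(seed)
    return [holdout_split(n, val_frac, seed=rng.randrange(2**32)) for _ in range(repeats)]

def stratified_kfold_splits(labels, k, seed=0):
    """Deal each class's shuffled indices round-robin into folds to preserve class proportions."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    folds = [[] for _ in range(k)]
    pos = 0
    for y in sorted(by_class):
        members = by_class[y]
        rng.shuffle(members)
        for i in members:
            folds[pos % k].append(i)  # continue round-robin across classes to balance fold sizes
            pos += 1
    return _pairs_from_folds(folds)
```

For example, with six "spam" and three "ham" labels, `stratified_kfold_splits(labels, 3)` puts two "spam" and one "ham" index in every validation fold, matching the 2:1 ratio of the full dataset.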
Machine learning veterans and beginners alike run into trouble with classification and overfitting. While the challenges surrounding machine learning can seem daunting, leveraging the right tools and utilizing the validation methods covered here will help engineers and scientists apply machine learning more easily to real-world projects.
To learn more about classification modeling and overfitting and how MATLAB is helping to overcome these machine learning challenges, see the links below or email me at sdeland@mathworks.com.
- Test-Drive the Classification Learner App: Create and train a machine learning model by running MATLAB right in your browser.
- Classification Examples: Check out these 25 classification examples, including discriminant analysis and Naïve Bayes.
- Improve Shallow Neural Network Generalization and Avoid Overfitting: Here is an example of overfitting.