Which machine learning classifier to choose, in general?
Suppose I'm working on some classification problem. (Fraud detection and comment spam are two problems I'm working on right now, but I'm curious about any classification task in general.)
How do I know which classifier I should use?
- Decision tree
- SVM
- Bayesian
- Neural network
- K-nearest neighbors
- Q-learning
- Genetic algorithm
- Markov decision processes
- Convolutional neural networks
- Linear regression or logistic regression
- Boosting, bagging, ensembling
- Random hill climbing or simulated annealing
- ...
In which cases is one of these the "natural" first choice, and what are the principles for choosing that one?
Examples of the type of answers I'm looking for (from Manning et al.'s Introduction to Information Retrieval book):
a. If your data is labeled, but you only have a limited amount, you should use a classifier with high bias (for example, Naive Bayes).
I'm guessing this is because a higher-bias classifier will have lower variance, which is good because of the small amount of data.
b. If you have a ton of data, then the classifier doesn't really matter so much, so you should probably just choose a classifier with good scalability.
What are other guidelines? Even answers like "if you'll have to explain your model to some upper management person, then maybe you should use a decision tree, since the decision rules are fairly transparent" are good. I care less about implementation/library issues, though.
Also, for a somewhat separate question, besides standard Bayesian classifiers, are there 'standard state-of-the-art' methods for comment spam detection (as opposed to email spam)?
First of all, you need to identify your problem. It depends upon what kind of data you have and what your desired task is.
If you are predicting a category:
- You have labeled data: you need to follow a classification approach and its algorithms.
- You don't have labeled data: you need to go for a clustering approach.
If you are predicting a quantity:
- You need to go for a regression approach.
Otherwise:
- You can go for a dimensionality reduction approach.
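The decision flow above can be written down as a tiny function. This is only a sketch of the flow as stated in this answer; the function name and the string labels are hypothetical, chosen for illustration.

```python
def choose_approach(goal, has_labels=False):
    """Map the questions above to a family of methods.

    `goal` is "category", "quantity", or anything else (hypothetical labels).
    """
    if goal == "category":
        # Labeled data means classification, unlabeled data means clustering.
        return "classification" if has_labels else "clustering"
    if goal == "quantity":
        return "regression"
    return "dimensionality reduction"


print(choose_approach("category", has_labels=True))
print(choose_approach("quantity"))
```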
There are different algorithms within each approach mentioned above. The choice of a particular algorithm depends upon the size of the dataset.
Source: http://scikit-learn.org/stable/tutorial/machine_learning_map/
Model selection using cross validation may be what you need.
Cross validation
What you do is simply split your dataset into k non-overlapping subsets (folds), train a model using k-1 folds, and estimate its performance on the fold you left out. You do this for each choice of held-out fold (first leave the 1st fold out, then the 2nd, ..., then the kth, and train with the remaining folds). After finishing, you compute the mean performance over all folds (maybe also the variance/standard deviation of the performance).
How to choose the parameter k depends on the time you have. Usual values for k are 3, 5, 10 or even N, where N is the size of your data (that's the same as leave-one-out cross validation). I prefer 5 or 10.
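Here is a minimal sketch of k-fold cross validation using only the Python standard library. The helper names (`k_fold_indices`, `cross_validate`, `fit_majority`) and the toy data are assumptions made for illustration; a `fit` function is assumed to take training data and return a predict function.

```python
import random
from statistics import mean, stdev


def k_fold_indices(n, k, seed=0):
    """Shuffle 0..n-1 and split into k non-overlapping folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def cross_validate(fit, X, y, k=5, seed=0):
    """Return the accuracy on each held-out fold."""
    scores = []
    for held_out in k_fold_indices(len(X), k, seed):
        held = set(held_out)
        train = [i for i in range(len(X)) if i not in held]
        # Train on the other k-1 folds, score on the fold left out.
        predict = fit([X[i] for i in train], [y[i] for i in train])
        correct = sum(predict(X[i]) == y[i] for i in held_out)
        scores.append(correct / len(held_out))
    return scores


def fit_majority(X_train, y_train):
    """Toy baseline: always predict the most common training label."""
    most_common = max(set(y_train), key=y_train.count)
    return lambda x: most_common


# Toy data: 30 points, two thirds of them labeled 1.
X = [[float(i)] for i in range(30)]
y = [1 if i % 3 else 0 for i in range(30)]

scores = cross_validate(fit_majority, X, y, k=5)
print(mean(scores), stdev(scores))
```

Setting `k=len(X)` in this sketch gives leave-one-out cross validation.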
Model selection
Let's say you have 5 methods (ANN, SVM, KNN, etc) and 10 parameter combinations for each method (depending on the method). You simply have to run cross validation for each method and parameter combination (5 * 10 = 50) and select the best model, method and parameters. Then you re-train with the best method and parameters on all your data and you have your final model.
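As a rough sketch of that loop, the code below scores every (method, parameters) candidate with cross validation, keeps the best one, and re-trains it on all the data. The two toy methods (`make_knn`, `make_majority`), the candidate list, and the data are hypothetical stand-ins for your real methods and parameter grids.

```python
import random


def cv_accuracy(fit, X, y, k=5, seed=0):
    """Mean k-fold accuracy; `fit(X_train, y_train)` returns a predict function."""
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    scores = []
    for held_out in (idx[i::k] for i in range(k)):
        held = set(held_out)
        train = [i for i in idx if i not in held]
        predict = fit([X[i] for i in train], [y[i] for i in train])
        scores.append(sum(predict(X[i]) == y[i] for i in held_out) / len(held_out))
    return sum(scores) / len(scores)


def make_knn(n_neighbors):
    """k-nearest neighbours with squared Euclidean distance and majority vote."""
    def fit(X_train, y_train):
        def predict(x):
            order = sorted(range(len(X_train)),
                           key=lambda i: sum((a - b) ** 2 for a, b in zip(X_train[i], x)))
            votes = [y_train[i] for i in order[:n_neighbors]]
            return max(set(votes), key=votes.count)
        return predict
    return fit


def make_majority():
    """Baseline that always predicts the most common training label."""
    def fit(X_train, y_train):
        label = max(set(y_train), key=y_train.count)
        return lambda x: label
    return fit


# Every (method, parameters) pair is one candidate to cross-validate.
candidates = [("knn", make_knn, {"n_neighbors": n}) for n in (1, 3, 5)]
candidates.append(("majority", make_majority, {}))

# Toy data: two well-separated groups on a line.
X = [[i * 0.1] for i in range(20)] + [[10 + i * 0.1] for i in range(20)]
y = [0] * 20 + [1] * 20

scored = [(cv_accuracy(make(**params), X, y), name, params)
          for name, make, params in candidates]
best_score, best_name, best_params = max(scored, key=lambda item: item[0])

# Re-train the winner on all the data to get the final model.
makers = {name: make for name, make, _ in candidates}
final_predict = makers[best_name](**best_params)(X, y)
print(best_name, best_params, best_score)
```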
There are some more things to say. If, for example, you use a lot of methods and parameter combinations for each, it's very likely you will overfit. In cases like these, you have to use nested cross validation.
Nested cross validation
In nested cross validation, you perform cross validation on the model selection algorithm.
Again, you first split your data into k folds. At each step, you take k-1 folds as your training data and the remaining one as your test data. Then you run model selection (the procedure I explained above) on the training folds only, and you repeat this for each choice of test fold. After finishing this, you will have k models, one for each choice. You then test each model on its held-out test fold; because that fold played no part in the selection, the scores tell you how well the selection procedure generalizes, and you can choose the best method and parameters accordingly. Finally, you train a new model with that method and those parameters on all the data you have. That's your final model.
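A minimal sketch of one common reading of nested cross validation: an inner cross validation picks the parameters using the training folds only, and an outer loop scores the selected model on a fold the selection never saw. The toy 1-D k-nearest-neighbours classifier, the helper names, and the data are assumptions for illustration.

```python
import random


def split_folds(indices, k, seed=0):
    """Shuffle the given indices and split them into k folds."""
    idx = list(indices)
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def accuracy(predict, X, y, indices):
    return sum(predict(X[i]) == y[i] for i in indices) / len(indices)


def fit_knn(X, y, train, n_neighbors):
    """k-nearest neighbours on 1-D features, using only the rows in `train`."""
    def predict(x):
        nearest = sorted(train, key=lambda i: abs(X[i] - x))[:n_neighbors]
        votes = [y[i] for i in nearest]
        return max(set(votes), key=votes.count)
    return predict


def select_model(X, y, train, candidates, k=3):
    """Inner loop: plain cross validation restricted to `train`."""
    def cv_score(n_neighbors):
        scores = []
        for held_out in split_folds(train, k):
            held = set(held_out)
            inner_train = [i for i in train if i not in held]
            model = fit_knn(X, y, inner_train, n_neighbors)
            scores.append(accuracy(model, X, y, held_out))
        return sum(scores) / len(scores)
    return max(candidates, key=cv_score)


def nested_cv(X, y, candidates, k_outer=5):
    """Outer loop: score the selection procedure on folds it never saw."""
    outer_scores = []
    for held_out in split_folds(range(len(X)), k_outer, seed=1):
        held = set(held_out)
        train = [i for i in range(len(X)) if i not in held]
        best = select_model(X, y, train, candidates)
        outer_scores.append(accuracy(fit_knn(X, y, train, best), X, y, held_out))
    return outer_scores


# Toy data: two well-separated groups on a line.
X = [i * 0.1 for i in range(20)] + [10 + i * 0.1 for i in range(20)]
y = [0] * 20 + [1] * 20

outer_scores = nested_cv(X, y, candidates=[1, 3, 5])
print(sum(outer_scores) / len(outer_scores))
```

The mean of `outer_scores` is the performance estimate for the whole "select, then train" procedure; the final model would still be trained on all the data.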
Of course, there are many variations of these methods and other things I didn't mention. If you need more information about these look for some publications about these topics.
The book "OpenCV" has two great pages on this (pages 462-463). Searching the Amazon preview for the word "discriminative" (probably Google Books also) will let you see the pages in question. These two pages are the greatest gem I have found in this book.
In short:
Boosting - often effective when a large amount of training data is available.
Random trees - often very effective and can also perform regression.
K-nearest neighbors - simplest thing you can do, often effective but slow and requires lots of memory.
Neural networks - Slow to train but very fast to run, still optimal performer for letter recognition.
SVM - Among the best with limited data, but loses to boosting or random trees when large data sets are available.
Things you might consider in choosing which algorithm to use would include:
Do you need to train incrementally (as opposed to batched)?
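If you need to update your classifier with new data frequently (or you have tons of data), you'll probably want to use Bayesian, since its counts can be updated one example at a time. SVMs in their standard form need to work on the training data in one go; neural nets are usually retrained in batches, although they can be updated incrementally with stochastic gradient descent.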
Is your data composed of categorical only, or numeric only, or both?
I think Bayesian works best with categorical/binomial data. Classification trees predict class labels rather than numerical values (regression trees are the variant for numeric targets).
Do you or your audience need to understand how the classifier works?
Use Bayesian or decision trees, since these can be easily explained to most people. Neural networks and SVM are "black boxes" in the sense that you can't really see how they are classifying data.
How much classification speed do you need?
SVMs are fast when it comes to classifying, since a linear SVM only needs to determine which side of the "line" (the separating hyperplane) your data is on. Decision trees can be slow, especially when they're complex (e.g. lots of branches).
Complexity.
Neural nets and SVMs can handle complex non-linear classification.
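To illustrate the incremental-training point, here is a minimal sketch of a naive Bayes classifier over categorical features that is updated one example at a time by adjusting counts. The class name, the Laplace smoothing default, and the toy spam/ham features are assumptions for illustration, not a reference implementation.

```python
import math
from collections import defaultdict


class IncrementalNaiveBayes:
    """Naive Bayes over categorical features that learns one example at a time."""

    def __init__(self, smoothing=1.0):
        self.smoothing = smoothing
        self.class_counts = defaultdict(int)
        # feature_counts[label][(position, value)] -> count
        self.feature_counts = defaultdict(lambda: defaultdict(int))
        # Distinct values seen so far at each feature position.
        self.values = defaultdict(set)

    def update(self, features, label):
        """Fold a single labeled example into the counts; no full retraining."""
        self.class_counts[label] += 1
        for pos, value in enumerate(features):
            self.feature_counts[label][(pos, value)] += 1
            self.values[pos].add(value)

    def predict(self, features):
        """Return the label with the highest smoothed log-probability."""
        total = sum(self.class_counts.values())
        best_label, best_score = None, -math.inf
        for label, n in self.class_counts.items():
            score = math.log(n / total)
            for pos, value in enumerate(features):
                count = self.feature_counts[label].get((pos, value), 0)
                vocab = len(self.values.get(pos, ())) or 1
                score += math.log((count + self.smoothing)
                                  / (n + self.smoothing * vocab))
            if score > best_score:
                best_label, best_score = label, score
        return best_label


model = IncrementalNaiveBayes()
model.update(("free", "link"), "spam")
model.update(("free", "nolink"), "spam")
model.update(("hello", "nolink"), "ham")
first = model.predict(("free", "link"))

# New data arrives later: just update the counts.
model.update(("hello", "link"), "ham")
second = model.predict(("hello", "nolink"))
print(first, second)
```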
As Prof Andrew Ng often states: always begin by implementing a rough, dirty algorithm, and then iteratively refine it.
For classification, Naive Bayes is a good starter: it performs well, is highly scalable, and can adapt to almost any kind of classification task. 1NN (K-Nearest Neighbours with only 1 neighbour) is also a no-hassle best-fit algorithm, because the data is the model and you therefore don't have to care about the dimensionality fit of your decision boundary. The only issue is the computation cost: quadratic, because you need to compute the distance matrix, so it may not be a good fit for large or high-dimensional data.
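A minimal sketch of 1NN shows why "the data is the model": there is no training step, and each prediction scans every stored point. The function name and the toy points are assumptions for illustration; `math.dist` requires Python 3.8 or later.

```python
import math


def nearest_neighbour_predict(train_X, train_y, x):
    """1NN: return the label of the closest stored training point."""
    # Every prediction compares x against all stored points.
    best = min(range(len(train_X)), key=lambda i: math.dist(train_X[i], x))
    return train_y[best]


train_X = [(1.0, 1.0), (1.5, 2.0), (8.0, 8.0), (9.0, 7.5)]
train_y = ["ham", "ham", "spam", "spam"]

print(nearest_neighbour_predict(train_X, train_y, (2.0, 1.0)))
print(nearest_neighbour_predict(train_X, train_y, (8.5, 8.0)))
```

Classifying m points against n stored points costs on the order of m * n distance computations in this sketch, which is where the cost concern comes from.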
Another good starter algorithm is Random Forests (composed of decision trees), which scales well to any number of dimensions and generally performs quite acceptably. Finally, there are genetic algorithms, which scale admirably well to any dimension and any data with minimal knowledge of the data itself. The most minimal and simplest implementation is the microbial genetic algorithm (only one line of C code, by Inman Harvey in 1996), and among the most complex are CMA-ES and MOGA/e-MOEA.
And remember that, often, you can't really know what will work best on your data before you try the algorithms for real.
As a side note, if you want a theoretical framework to test your hypotheses and your algorithms' theoretical performance on a given problem, you can use the PAC (Probably Approximately Correct) learning framework (beware: it's very abstract and complex!). To summarize, the gist of PAC learning is that you should use the least complex algorithm that is still complex enough to fit your data (complexity being the maximum dimensionality that the algorithm can fit). In other words, apply Occam's razor.
Sam Roweis used to say that you should try naive Bayes, logistic regression, k-nearest neighbour and Fisher's linear discriminant before anything else.
My take on it is that you always run the basic classifiers first to get some sense of your data. More often than not (in my experience at least) they've been good enough.
So, if you have labeled data, train a Naive Bayes classifier. If you have unlabeled data, you can try k-means clustering.
Another resource is one of the lecture videos of the series of videos Stanford Machine Learning, which I watched a while back. In video 4 or 5, I think, the lecturer discusses some generally accepted conventions when training classifiers, advantages/tradeoffs, etc.
You should always take into account the inference vs prediction trade-off.
If you want to understand the complex relationship that is occurring in your data then you should go with a rich inference algorithm (e.g. linear regression or lasso). On the other hand, if you are only interested in the result you can go with high dimensional and more complex (but less interpretable) algorithms, like neural networks.
The choice of algorithm depends on the scenario and on the type and size of the data set. There are many other factors.
This is a brief cheat sheet for basic machine learning.