Using Gradient Boosting (with Early Stopping)

Whenever you train a machine learning algorithm, you need to balance learning from the data against overfitting it. Overfitting means that the algorithm will perform well on the training data, but will not generalize well to new data. For many algorithms, the standard way to avoid this problem is regularization. For gradient boosting, the most straightforward way is to use cross-validation to select the learning rate and the number of trees to train.
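For reference, a cross-validated search over these two parameters might look like the sketch below, using scikit-learn's GridSearchCV on a toy dataset; the grid values are placeholders rather than recommendations.

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

# Toy data and an illustrative grid; the values are placeholders.
X, y = make_classification(n_samples=500, random_state=0)
param_grid = {
    'n_estimators': [100, 500, 1000],
    'learning_rate': [0.001, 0.01, 0.1],
}
search = GridSearchCV(GradientBoostingClassifier(random_state=0),
                      param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)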

In this post, I will go over the most important gradient boosting parameters and describe how to implement a technique called early stopping in scikit-learn in order to avoid overfitting.

Gradient Boosting Hyperparameters

There are 4 important parameters to tune in gradient boosted trees:

  • Number of trees

    To calculate a prediction, an ensemble computes the prediction of each tree and then combines them. In general, you want the number of trees to be as high as possible (>1000), but setting it too high can lead to overfitting, which the other parameters, as well as early stopping, help avoid.

  • Learning rate

    The learning rate scales how much each tree's prediction changes the overall prediction. With a learning rate of 0.1, for example, each tree's prediction is scaled by 0.1, resulting in smaller steps in the gradient direction. Big steps mean faster convergence, but the boosting may not reach the best optimum. You generally want this to be small (<0.01).

  • Maximum tree depth & minimum samples per leaf

    The maximum tree depth and the minimum samples per leaf are closely linked: the goal is to avoid trees so deep that each sample ends up in its own leaf, while still having enough depth to learn complex interactions between features. Good values for the maximum tree depth are generally between 1 and 10, and the minimum samples per leaf depends on the size of the training set (a bigger training set calls for a larger minimum, while a higher maximum tree depth calls for a lower minimum).

Some other potentially useful parameters are subsampling - training trees on a random subset of the data - and randomly selecting a subset of features to use for the tree splits.
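To make the mapping to scikit-learn concrete, here is a sketch of how these knobs correspond to the constructor parameters of GradientBoostingClassifier; the values are illustrative only, not recommendations.

from sklearn.ensemble import GradientBoostingClassifier

# Illustrative values only; the number of trees is set high on purpose,
# since early stopping (described below) can cut training short.
clf = GradientBoostingClassifier(
    n_estimators=2000,      # number of trees
    learning_rate=0.01,     # each tree's contribution is scaled by this factor
    max_depth=3,            # maximum tree depth
    min_samples_leaf=20,    # minimum samples per leaf
    subsample=0.8,          # train each tree on a random 80% of the rows
    max_features='sqrt',    # random subset of features for each split
)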

Early Stopping

A typical training curve looks like this:

[Figure: idealized training and validation error curves as the number of trees grows]
No matter how well tuned all of these parameters are, you still need to select the best number of trees (the one that minimizes the validation error). The idea behind early stopping is to stop training when the model begins to overfit. One obvious way is to stop when the validation set error begins to increase. This would work well if the curve really looked like the one above, but a real curve looks more like this:

[Figure: a noisy validation error curve with a local minimum around 50 iterations and a lower minimum later on]

In that case, the naive method would stop training at around 50 iterations, missing the better minimum that comes later.

There are several methods to detect when overfitting begins. I will describe two of them here, but for more details you can read the article by Lutz Prechelt: Early stopping - but when? (Warning: PDF).

Generalization loss

This criterion looks at the relative increase of the validation error over the minimum validation error seen so far:

    GL(t) = 100 * (E_va(t) / E_opt(t) - 1)

where E_va(t) is the validation error after t iterations and E_opt(t) is the lowest validation error observed up to iteration t. We stop training as soon as the generalization loss exceeds a certain threshold.
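As a rough sketch, the generalization loss can be computed from the history of validation errors; the function name and the 5% threshold below are illustrative choices, not taken from the original post.

def generalization_loss(val_errors):
    """Percentage increase of the latest validation error over the best
    validation error seen so far."""
    return 100.0 * (val_errors[-1] / min(val_errors) - 1.0)

# Stop as soon as the generalization loss exceeds a chosen threshold,
# for example 5 (i.e. 5% worse than the best error seen so far).
errors = [0.40, 0.31, 0.27, 0.26, 0.28]
if generalization_loss(errors) > 5:
    print("stop training")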

Consecutive decreases

This criterion looks at whether or not the validation error increases. We stop training as soon as the number of successive increases exceeds a certain threshold. The naive condition described above corresponds to this criterion with a threshold of 1.

Scikit-learn Implementation

In scikit-learn's gradient boosting classes, GradientBoostingClassifier and GradientBoostingRegressor, the fit method takes an optional monitor argument that can be used to implement early stopping. The monitor is called after each tree is fit, and when it returns True, training stops.

Here is a way to implement it for the consecutive decreases criterion.

from sklearn.ensemble._gradient_boosting import predict_stage

class Monitor():
    """Monitor for early stopping in Gradient Boosting for classification.

    The monitor checks the validation loss between each training stage. When
    too many successive stages have increased the loss, the monitor will return
    True, stopping the training early.

    X_valid : array-like, shape = [n_samples, n_features]
      Validation vectors, where n_samples is the number of samples
      and n_features is the number of features.
    y_valid : array-like, shape = [n_samples]
      Target values (integers in classification, real numbers in
      regression). For classification, labels must correspond to classes.
    max_consecutive_decreases : int, optional (default=5)
      Early stopping criterion: when the number of consecutive iterations that
      result in a worse performance on the validation set exceeds this value,
      the training stops.
    """

    def __init__(self, X_valid, y_valid, max_consecutive_decreases=5):
        self.X_valid = X_valid
        self.y_valid = y_valid
        self.max_consecutive_decreases = max_consecutive_decreases
        self.losses = []

    def __call__(self, i, clf, args):
        # Called by fit after stage i; clf is the boosting estimator and args
        # receives the local variables of fit.
        if i == 0:
            self.consecutive_decreases_ = 0
            # Decision function of the initial estimator on the validation set
            # (note: this relies on scikit-learn private API).
            self.predictions = clf._init_decision_function(self.X_valid)

        # Add the contribution of the stage that was just fit to the running
        # predictions, then record the resulting validation loss.
        predict_stage(clf.estimators_, i, self.X_valid, clf.learning_rate,
                      self.predictions)
        self.losses.append(clf.loss_(self.y_valid, self.predictions))

        if len(self.losses) >= 2 and self.losses[-1] > self.losses[-2]:
            self.consecutive_decreases_ += 1
        else:
            self.consecutive_decreases_ = 0

        if self.consecutive_decreases_ >= self.max_consecutive_decreases:
            print("Too many consecutive decreases of validation performance "
                  "({}): stopping early at iteration {}.".format(
                      self.consecutive_decreases_, i))
            return True
        else:
            return False

We're using the predict_stage function, which adds the contribution of the tree that was just fit to the running prediction. This class can easily be modified for another stopping criterion (such as the generalization loss) by changing the condition under which it returns True. Now, to use this, you just need to create an instance of the class and pass it when calling fit.
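For example, a minimal sketch might look like this, assuming a held-out validation split and a scikit-learn version whose private API matches the code above; the dataset and hyperparameter values are placeholders.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

# Toy data standing in for a real dataset.
X, y = make_classification(n_samples=2000, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, random_state=0)

clf = GradientBoostingClassifier(n_estimators=2000, learning_rate=0.01)
# predict_stage operates on float32 feature arrays, hence the cast.
monitor = Monitor(X_valid.astype(np.float32), y_valid,
                  max_consecutive_decreases=5)
clf.fit(X_train, y_train, monitor=monitor)
print("Trees actually fit:", len(clf.estimators_))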