Alink's Random Talk (12): the overall design of online learning algorithm FTRL


0x00 summary

Alink is a next-generation machine learning algorithm platform developed by Alibaba on top of Flink, a real-time computing engine. It is the industry's first machine learning platform that supports both batch and streaming algorithms. This article and the next one introduce how the online learning algorithm FTRL is implemented in Alink. I hope they help.

0x01 concept

Because Alink implements LR + FTRL, we need to start with logistic regression (LR).

1.1 logistic regression

Although it is called regression, logistic regression is actually a classification model and is commonly used for binary classification. Its essence is to assume that the data follow a logistic distribution and then estimate the parameters by maximum likelihood estimation.

The idea of logistic regression is to first fit a decision boundary (not necessarily linear; it can also be polynomial), and then establish a probabilistic relationship between this boundary and the classes, which yields class probabilities for binary classification.

1.1.1 derivation process

Let's start with linear regression. In some cases, fitting the data with a linear function and then applying a threshold does not work, because the fitted function is too "straight" and outliers have too much influence on the result. The overall approach is not wrong; what is wrong is using a fitting function that is too straight. Would a nonlinear, less straight fitting function work better?

So let's do two things:

  • Reduce how strongly outliers affect the regression function
  • Select a threshold

For the first thing, we use the sigmoid function to bend the regression function.

For binary classification problems, 1 represents the positive class and 0 the negative class. Logistic regression takes the real-valued output of a linear function and passes it through a hypothesis function \(h_\theta(x) = g(\theta^T x)\), which maps that real value into the interval (0, 1). In logistic regression, the logistic (log-odds) function is chosen as the activation function; it is the most important representative of the sigmoid (S-shaped) functions.

For the second thing, we chose a threshold of 0.5.

That is, with a threshold of 0.5, any value below 0.5 is classified as negative, even 0.49. Is that judgment correct? Not necessarily, because the sample still has a 49% probability of being positive. Likewise, even if a sample's probability of being positive is only 0.1, predicting 10,000 such samples as negative will still misclassify roughly 1,000 samples that are actually positive. Whatever threshold we choose, some error remains, so we choose the threshold that gives an acceptable error.

1.1.2 solution

So far, we have basically understood the origin of logistic regression. We know that the discriminant function of logistic regression is

\[h(z) = \frac{1}{1+e^{-z}},\quad z = W^TX \]
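To make the discriminant function and the 0.5 threshold concrete, here is a minimal Java sketch. It is not Alink code; the class name LogisticPredict and the sample weights are hypothetical. It computes h(z) for z = w·x and turns the probability into a label:

```java
// Minimal sketch: LR discriminant h(z) = 1 / (1 + e^{-z}) with z = w·x,
// followed by the 0.5 threshold discussed above. Weights are made up.
final class LogisticPredict {

    // Sigmoid: maps any real z into (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // z = w·x, the linear part of the model.
    static double dot(double[] w, double[] x) {
        double z = 0.0;
        for (int j = 0; j < w.length; j++) {
            z += w[j] * x[j];
        }
        return z;
    }

    // Predicted probability of the positive class.
    static double predictProba(double[] w, double[] x) {
        return sigmoid(dot(w, x));
    }

    // Hard label: values below the threshold are negative (0), the rest positive (1).
    static int predictLabel(double[] w, double[] x, double threshold) {
        return predictProba(w, x) >= threshold ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0};  // hypothetical weights
        double[] x = {2.0, 1.0};   // z = 0.5 * 2 - 1 * 1 = 0
        System.out.println(predictProba(w, x));       // 0.5
        System.out.println(predictLabel(w, x, 0.5));  // 1
    }
}
```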

How do we solve logistic regression? That is, how do we find the weights W for which h(z) predicts the correct class probabilities as well as possible?

There are many methods for solving logistic regression; here we mainly discuss gradient descent and Newton's method. The main goal of optimization is to find a direction in which the loss function decreases. This direction is usually obtained from combinations of first-order or second-order partial derivatives.

Gradient descent finds the descent direction from the first derivative of the loss J(w) with respect to w and updates the parameters iteratively.

The basic idea of Newton's method is to take a second-order Taylor expansion of J(w) around the current estimate of the minimum, and then use the minimum of that expansion as the next estimate.
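A minimal Java sketch of gradient descent for logistic regression, assuming the usual average log-loss J(w) = -(1/N) Σ [y log h + (1 - y) log(1 - h)], whose gradient is (1/N) Σ (h - y) x. The class name, the toy data set and the step size are hypothetical, and Newton's method is not shown:

```java
// Minimal sketch of batch gradient descent for logistic regression (labels 0/1).
final class LrGradientDescent {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    // Average log-loss J(w) over all samples.
    static double loss(double[] w, double[][] xs, int[] ys) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double h = sigmoid(dot(w, xs[i]));
            total -= ys[i] == 1 ? Math.log(h) : Math.log(1 - h);
        }
        return total / xs.length;
    }

    // First derivative of J(w) with respect to w: (1/N) * sum (h - y) * x.
    static double[] gradient(double[] w, double[][] xs, int[] ys) {
        double[] g = new double[w.length];
        for (int i = 0; i < xs.length; i++) {
            double err = sigmoid(dot(w, xs[i])) - ys[i];
            for (int j = 0; j < w.length; j++) {
                g[j] += err * xs[i][j];
            }
        }
        for (int j = 0; j < w.length; j++) {
            g[j] /= xs.length;
        }
        return g;
    }

    // Start from w = 0 and repeatedly step against the gradient.
    static double[] train(double[][] xs, int[] ys, double lr, int iters) {
        double[] w = new double[xs[0].length];
        for (int it = 0; it < iters; it++) {
            double[] g = gradient(w, xs, ys);
            for (int j = 0; j < w.length; j++) {
                w[j] -= lr * g[j];
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // first column is a constant 1 acting as the intercept
        double[][] xs = {{1, -2}, {1, -1}, {1, 1}, {1, 2}};
        int[] ys = {0, 0, 1, 1};
        double[] w = train(xs, ys, 0.5, 100);
        System.out.printf("loss before=%.4f after=%.4f%n", loss(new double[2], xs, ys), loss(w, xs, ys));
    }
}
```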

1.1.3 stochastic gradient descent

When the number of samples N is large, stochastic gradient descent is usually used. The algorithm is as follows:

repeat until convergence {
    for i in range(0, m):
        for each feature j:
            w_j = w_j + a * (y_i - h(x_i)) * x_ij   # a is the learning rate
}

An advantage of stochastic gradient descent is that it can be parallelized in a distributed setting. The calculation proceeds as follows:

  1. At each iteration, randomly sample a certain proportion of the samples as the training samples for this iteration.
  2. For each sampled example, compute the gradient with respect to each feature.
  3. Using an aggregation function, accumulate the per-feature gradients over all sampled examples to obtain the cumulative gradient and loss of each feature.
  4. Update the parameters from the latest gradient and the previous parameters.
  5. Compute the loss with the updated parameters. If the loss is within the allowed range, stop iterating; otherwise, repeat from step 1.
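The five steps can be sketched on a single machine as follows. This is an illustration under assumptions, not Alink's implementation: the aggregation function is a plain sum, and fraction, lr, tolerance and maxIter are hypothetical parameters. In a real cluster, each worker would sum the gradients of its own share of the sampled rows, and the partial sums would then be merged.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Sketch of sampled (mini-batch) gradient descent for LR with gradient aggregation.
final class SampledSgd {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    // Average log-loss over the whole data set (labels 0/1).
    static double loss(double[] w, double[][] xs, int[] ys) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double h = sigmoid(dot(w, xs[i]));
            total -= ys[i] == 1 ? Math.log(h) : Math.log(1 - h);
        }
        return total / xs.length;
    }

    static double[] train(double[][] xs, int[] ys, double fraction, double lr,
                          double tolerance, int maxIter, long seed) {
        Random rnd = new Random(seed);
        double[] w = new double[xs[0].length];
        for (int iter = 0; iter < maxIter; iter++) {
            // 1. randomly sample a fraction of the samples for this iteration
            List<Integer> batch = new ArrayList<>();
            for (int i = 0; i < xs.length; i++) {
                if (rnd.nextDouble() < fraction) {
                    batch.add(i);
                }
            }
            if (batch.isEmpty()) {
                continue;
            }
            // 2 + 3. per-sample, per-feature gradients, accumulated by summation
            double[] gradSum = new double[w.length];
            for (int i : batch) {
                double err = sigmoid(dot(w, xs[i])) - ys[i];
                for (int j = 0; j < w.length; j++) {
                    gradSum[j] += err * xs[i][j];
                }
            }
            // 4. update the parameters from the averaged accumulated gradient
            for (int j = 0; j < w.length; j++) {
                w[j] -= lr * gradSum[j] / batch.size();
            }
            // 5. stop once the loss falls within the allowed range
            if (loss(w, xs, ys) < tolerance) {
                break;
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[][] xs = {{1, -2}, {1, -1}, {1, 1}, {1, 2}};
        int[] ys = {0, 0, 1, 1};
        double[] w = train(xs, ys, 0.5, 0.5, 0.05, 500, 42L);
        System.out.println("final loss = " + loss(w, xs, ys));
    }
}
```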

1.2 LR parallel computing

All of these solution methods need the gradient, so the key to parallelizing logistic regression is parallelizing the gradient computation of the objective function.

Computing the gradient vector of the objective function only requires dot products and additions between vectors. Each iteration can therefore easily be split into independent steps that different nodes compute separately, after which the results are merged.

Parallel LR therefore parallelizes the computation of the descent direction (the gradient) of the loss function while searching for its optimal solution; the step that determines the descent direction from the gradient can be parallelized as well.

If the sample matrix is partitioned by rows, the samples' feature vectors are distributed to different compute nodes. Each node computes the dot products and sums for the samples it is responsible for, and the results are then merged. This is "LR parallel by row".
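Row-parallel LR can be sketched in Java by treating parallel stream tasks as the compute nodes. This is an assumption for illustration (a real system ships row blocks to separate machines), and the class name is hypothetical. Each task computes the gradient sum over its block of rows, and the merge step is an element-wise sum:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Sketch of "LR parallel by row": rows are split into blocks, each block's
// gradient sum is computed independently, and the partial results are merged.
final class RowParallelGradient {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    // Work done by one node: unnormalized gradient sum over rows [from, to).
    static double[] partialGradient(double[] w, double[][] xs, int[] ys, int from, int to) {
        double[] g = new double[w.length];
        for (int i = from; i < to; i++) {
            double err = sigmoid(dot(w, xs[i])) - ys[i];
            for (int j = 0; j < w.length; j++) {
                g[j] += err * xs[i][j];
            }
        }
        return g;
    }

    // Merge step: element-wise sum of two partial results (returns a new array).
    static double[] add(double[] a, double[] b) {
        double[] c = a.clone();
        for (int j = 0; j < c.length; j++) {
            c[j] += b[j];
        }
        return c;
    }

    // Split rows into numNodes contiguous blocks, process them in parallel, merge, average.
    static double[] gradient(double[] w, double[][] xs, int[] ys, int numNodes) {
        int n = xs.length;
        double[] sum = IntStream.range(0, numNodes).parallel()
                .mapToObj(k -> partialGradient(w, xs, ys, k * n / numNodes, (k + 1) * n / numNodes))
                .reduce(new double[w.length], RowParallelGradient::add);
        for (int j = 0; j < sum.length; j++) {
            sum[j] /= n;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] xs = {{1, -2}, {1, -1}, {1, 1}, {1, 2}, {1, 3}};
        int[] ys = {0, 0, 1, 1, 1};
        double[] w = {0.1, 0.2};
        System.out.println(Arrays.toString(gradient(w, xs, ys, 1))
                + " vs " + Arrays.toString(gradient(w, xs, ys, 3)));
    }
}
```

Because merging is just addition, the row-split result matches the single-node gradient up to floating-point rounding.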

LR parallel by row solves the problem of large sample sizes. In practice, however, logistic regression must also handle high-dimensional feature vectors (for example, feature dimensions in advertising systems can reach hundreds of millions), and parallelizing by row alone cannot meet the needs of such scenarios. The high-dimensional feature vectors must therefore also be split by column into several smaller vectors.
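Splitting by column can be sketched the same way: each node holds one contiguous block of the feature vector and of w, computes a partial dot product, and the partial results are summed to get w·x. Once the full w·x is known, each node updates only its own block. The class name and the contiguous split rule in splitPoints are hypothetical choices for illustration, not necessarily how Alink divides features.

```java
// Sketch of column (feature) splitting for high-dimensional LR.
final class ColumnSplit {

    // Hypothetical split: boundaries of contiguous feature blocks, one block per node.
    static int[] splitPoints(int numFeatures, int numNodes) {
        int[] points = new int[numNodes + 1];
        for (int k = 0; k <= numNodes; k++) {
            points[k] = (int) ((long) k * numFeatures / numNodes);
        }
        return points;
    }

    // Node k only holds w[from, to) and x[from, to) and returns its partial dot product.
    static double partialDot(double[] w, double[] x, int from, int to) {
        double s = 0.0;
        for (int j = from; j < to; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    // "Reduce" step: summing the partial results gives the full w·x.
    static double splitDot(double[] w, double[] x, int numNodes) {
        int[] p = splitPoints(x.length, numNodes);
        double z = 0.0;
        for (int k = 0; k < numNodes; k++) {
            z += partialDot(w, x, p[k], p[k + 1]);
        }
        return z;
    }

    // After the reduce, every node knows err = h(z) - y and updates only its own block.
    static void updateBlock(double[] w, double[] x, double err, double lr, int from, int to) {
        for (int j = from; j < to; j++) {
            w[j] -= lr * err * x[j];
        }
    }

    public static void main(String[] args) {
        double[] w = {0.1, -0.2, 0.3, 0.0, 0.5, -0.1, 0.2, 0.4, -0.3, 0.05};
        double[] x = {1, 0, 2, 0, 1, 3, 0, 1, 0, 2};
        int nodes = 3;
        double z = splitDot(w, x, nodes);
        double err = 1.0 / (1.0 + Math.exp(-z)) - 1;  // gradient factor for label y = 1
        int[] p = splitPoints(x.length, nodes);
        for (int k = 0; k < nodes; k++) {
            updateBlock(w, x, err, 0.1, p[k], p[k + 1]);
        }
        System.out.println("z = " + z + ", w[0] after update = " + w[0]);
    }
}
```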

1.3 traditional machine learning

The traditional machine learning development process basically consists of the following steps:

  1. Data fusion, acquisition of training and evaluation data sets.
  2. Feature engineering.
  3. Build models, such as LR, FM, etc.
  4. Train the model to obtain the optimal solution.
  5. Evaluate the effect of the model.
  6. Save the model and deploy the validated model online for prediction.

There are two main bottlenecks in this approach:

  1. The model update cycle is slow and cannot reflect online changes promptly: at best it is hourly, and usually daily or even weekly.
  2. With few model parameters, prediction quality is poor; with many parameters, online prediction needs a lot of memory and the QPS cannot be guaranteed.

For example, a traditional batch algorithm uses the whole training set in every iteration (for example, to compute the global gradient). Its advantage is good accuracy and convergence. Its disadvantages are that it cannot efficiently handle large data sets (the global gradient becomes too expensive to compute) and cannot be applied to data streams for online learning.

Generally speaking, there are two solutions to these problems:

  • For problem 1, adopt online learning algorithms.
  • For problem 2, use optimization methods that produce solutions as sparse as possible while preserving accuracy, so as to reduce the number of model parameters.

1.4 online learning

Online Learning represents a family of machine learning algorithms. Its defining feature is that the model is trained on each individual sample, so it can be adjusted quickly and in real time according to online feedback data. The model thus reflects online changes promptly, which improves the accuracy of online prediction.

With the traditional training method, the model is generally static once it goes online and does not interact with online conditions at all. Prediction errors can only be corrected at the next update, and updates are usually far apart.

Online Learning trains differently: it adjusts the model dynamically according to online prediction results, incorporating prediction errors and correcting them promptly. Online Learning can therefore respond to online changes in a more timely manner.

The optimization goal of Online Learning is to minimize the overall loss function. It needs to quickly solve the optimal solution of the objective function.

The characteristic of an online learning algorithm is that each training sample is used, through the loss and gradient it produces, to update the model once; the model is trained sample by sample. It can therefore handle large amounts of training data as well as online training. Online gradient descent (OGD) and stochastic gradient descent (SGD) are commonly used. The essential idea is to perform gradient descent on the loss L(w, z_i) of a single sample z_i instead of on the total loss over all samples. Because the direction of each step is not globally optimal, the overall path looks like a seemingly random descent.
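A minimal online gradient descent sketch for logistic regression: each incoming sample triggers exactly one update and is then discarded. The class name, the decaying step size eta0 / sqrt(t) and the toy stream are assumptions for illustration, not Alink code.

```java
// Sketch of online gradient descent (OGD) for LR: one update per incoming sample.
final class OnlineGradientDescent {
    private final double[] w;
    private final double eta0;
    private long t = 0;  // number of samples seen so far

    OnlineGradientDescent(int dim, double eta0) {
        this.w = new double[dim];
        this.eta0 = eta0;
    }

    double predict(double[] x) {
        double z = 0.0;
        for (int j = 0; j < w.length; j++) {
            z += w[j] * x[j];
        }
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One gradient step on the single-sample loss L(w, z_i); the sample is not stored.
    void update(double[] x, int y) {
        t++;
        double eta = eta0 / Math.sqrt(t);  // decaying step size
        double err = predict(x) - y;        // derivative of the log-loss w.r.t. w·x
        for (int j = 0; j < w.length; j++) {
            w[j] -= eta * err * x[j];
        }
    }

    public static void main(String[] args) {
        OnlineGradientDescent model = new OnlineGradientDescent(2, 0.5);
        double[][] xs = {{1, -2}, {1, -1}, {1, 1}, {1, 2}};
        int[] ys = {0, 0, 1, 1};
        for (int t = 0; t < 200; t++) {  // simulate a stream by cycling the samples
            int i = t % xs.length;
            model.update(xs[i], ys[i]);
        }
        System.out.println(model.predict(new double[]{1, 2}));
    }
}
```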

1.5 FTRL

FTL (Follow The Leader) is the predecessor of FTRL. Its idea is to choose, at each step, the parameter that minimizes the sum of the loss functions of all previous samples.

FTRL, i.e. Follow The Regularized Leader, evolved from this earlier work. Its main goal is to improve sparsity while meeting accuracy requirements. FTRL adds regularization to the FTL optimization objective to prevent overfitting.

The FTRL loss function is generally not easy to solve directly, so a surrogate (proxy) loss function is usually needed.

The surrogate loss function needs to meet the following conditions:

  1. It is easy to solve, preferably with an analytical solution.
  2. The smaller the difference between the solution of the surrogate loss function and the solution of the original function, the better.

In order to measure the difference between the two solutions in condition 2, the concept of regret is introduced.

1.5.1 regret & sparsity

Generally, online learning aims to solve two problems: reducing regret and improving sparsity. Regret is defined as:

\[\text{Regret}=\sum_{t=1}^T\ell_t(w_t)-\min_w\sum_{t=1}^T\ell_t(w) \]

where t indexes the iteration among T rounds in total, \(\ell_t\) is the loss function of round t, and w denotes the parameters to be learned. Regret is the loss gap between the "solutions obtained from the surrogate function" and the "solution obtained from the true loss function".

The second term is the loss of the optimal solution computed after all samples have been seen. Because online learning can update parameters from only a few samples at a time, which is highly random, a robust optimization method is required. The name "regret" expresses the goal: after updating, there should be as little to regret as possible.
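To make the regret definition concrete, here is a toy Java computation. It only illustrates the definition and is not FTRL: the parameter w is a scalar, the per-round loss is assumed to be l_t(w) = (w - z_t)^2 / 2, and the learner is online gradient descent with step 1/t starting from w_1 = 0. For these losses that step size makes w_{t+1} the mean of z_1..z_t, which is exactly what FTL would choose; the best fixed parameter in hindsight is the mean of all z_t.

```java
// Toy regret computation: sum of online losses minus the loss of the best fixed w.
final class RegretDemo {

    // Assumed per-round loss l_t(w) = (w - z_t)^2 / 2.
    static double lossAt(double w, double z) {
        return 0.5 * (w - z) * (w - z);
    }

    static double regret(double[] zs) {
        double w = 0.0;           // w_1 = 0
        double onlineLoss = 0.0;
        for (int t = 1; t <= zs.length; t++) {
            double z = zs[t - 1];
            onlineLoss += lossAt(w, z);   // loss of w_t, chosen before seeing z_t
            w -= (1.0 / t) * (w - z);     // gradient step with eta_t = 1/t
        }
        // best fixed w in hindsight for squared loss: the mean of all z_t
        double mean = 0.0;
        for (double z : zs) {
            mean += z;
        }
        mean /= zs.length;
        double bestLoss = 0.0;
        for (double z : zs) {
            bestLoss += lossAt(mean, z);
        }
        return onlineLoss - bestLoss;
    }

    public static void main(String[] args) {
        for (int T : new int[]{10, 100, 1000}) {
            double[] zs = new double[T];
            java.util.Arrays.fill(zs, 1.0);
            System.out.println("T=" + T + " regret/T=" + regret(zs) / T);
        }
    }
}
```

For a constant stream z_t = 1, only the first round incurs loss, so the regret stays at 0.5 while Regret(T)/T shrinks as T grows, i.e. the regret is sublinear in T.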

It can be proved theoretically that if an online learning algorithm can guarantee that its regret is a sublinear function of t, then:

\[\lim_{t\to\infty}\frac{\text{Regret}(t)}{t}=0 \]

That is, as the number of training samples grows, the online learning model gets arbitrarily close to the optimal model: the average gap between the actual loss of the parameters obtained from the surrogate loss function and that of the original loss function becomes smaller and smaller. Unsurprisingly, FTRL has this property.

On the other hand, sparsity of the model is also valued in practice. Hundreds of millions of features are not uncommon, and the more complex the model, the more storage and time it requires. Sparse models greatly reduce the memory and computational cost of prediction, and they are also relatively interpretable. Producing such sparse models is the advantage of L1 regularization.

1.5.2 pseudo code of FTRL

Per coordinate means that FTRL trains and updates each dimension of w separately, and each dimension has its own learning rate. In the FTRL-Proximal weight formula this is the term (β + √n_i)/α that appears just before λ2 in the denominator, where α and β are learning-rate hyperparameters and n_i is the sum of squared gradients seen so far on dimension i. Compared with one learning rate shared by all dimensions of w, this takes into account that training samples are unevenly distributed across features. If few training samples contain a given feature, each such sample is precious, so that feature's learning rate can stay relatively large on its own. Every sample containing the feature can then take a big step along its gradient, without being forced to keep pace with the other feature dimensions.
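A minimal per-coordinate FTRL-Proximal sketch for logistic regression, written from the published FTRL-Proximal algorithm (McMahan et al., 2013) rather than from Alink's source; the class and method names are hypothetical. Each coordinate keeps its own z_i and n_i, so the learning rate (β + √n_i)/α is per coordinate, and any weight with |z_i| ≤ λ1 is exactly zero.

```java
import java.util.Arrays;

// Sketch of per-coordinate FTRL-Proximal for logistic regression on sparse samples.
final class FtrlProximal {
    private final double alpha, beta, l1, l2;
    private final double[] z;  // per-coordinate accumulated (adjusted) gradients
    private final double[] n;  // per-coordinate accumulated squared gradients

    FtrlProximal(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    // Closed-form weight of coordinate i; exactly 0 when |z_i| <= l1 (sparsity).
    double weight(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        double perCoordinateRate = (beta + Math.sqrt(n[i])) / alpha;  // the term before l2
        return -(z[i] - Math.signum(z[i]) * l1) / (perCoordinateRate + l2);
    }

    // Probability for a sparse sample given as feature indices and values.
    double predict(int[] idx, double[] val) {
        double s = 0.0;
        for (int k = 0; k < idx.length; k++) {
            s += weight(idx[k]) * val[k];
        }
        return 1.0 / (1.0 + Math.exp(-s));
    }

    // One online step: predict with the current weights, then update z_i and n_i
    // only for the features present in the sample (indices assumed distinct).
    void update(int[] idx, double[] val, int y) {
        double p = predict(idx, val);
        for (int k = 0; k < idx.length; k++) {
            int i = idx[k];
            double wi = weight(i);          // weight before this step's update
            double g = (p - y) * val[k];    // log-loss gradient for coordinate i
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * wi;
            n[i] += g * g;
        }
    }

    public static void main(String[] args) {
        FtrlProximal ftrl = new FtrlProximal(3, 0.5, 1.0, 0.01, 0.01);
        int[] pos = {0, 2};          // feature 0 plus the bias at index 2
        int[] neg = {1, 2};          // feature 1 plus the bias
        double[] ones = {1.0, 1.0};
        for (int t = 0; t < 100; t++) {
            ftrl.update(pos, ones, 1);
            ftrl.update(neg, ones, 0);
        }
        double[] w = {ftrl.weight(0), ftrl.weight(1), ftrl.weight(2)};
        System.out.println("weights = " + Arrays.toString(w));
    }
}
```

Weights are recomputed from z and n on demand rather than stored. With a very large l1 (for example 100 on this toy data), every |z_i| stays below the threshold and all weights remain exactly 0, which is the sparsity effect of L1 in FTRL.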

1.5.3 brief understanding

To deepen our understanding, let's look at the update formula for the feature weights at the next step (I personally find this explanation relatively easy to understand):
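Assuming the formula meant here is the standard FTRL-Proximal update (McMahan et al., 2013), which has exactly the three terms described next, it can be written as follows. Here g_s is the loss gradient at step s, the σ_s are chosen so that their sum up to t equals 1/η_t (per coordinate when learning rates are per coordinate), and λ1 is the L1 strength. The closed-form per-coordinate solution is shown underneath; implementations commonly also add a ½λ2‖w‖² term, which adds λ2 to the 1/η_{t,i} in the denominator.

```latex
w_{t+1} = \arg\min_{w}\Big( g_{1:t}\cdot w
  + \frac{1}{2}\sum_{s=1}^{t}\sigma_s\,\lVert w - w_s\rVert_2^2
  + \lambda_1\lVert w\rVert_1 \Big),
\qquad g_{1:t} = \sum_{s=1}^{t} g_s,\qquad \sum_{s=1}^{t}\sigma_s = \frac{1}{\eta_t}

% per coordinate i, with z_{t,i} = g_{1:t,i} - \sum_{s=1}^{t}\sigma_{s,i}\,w_{s,i}:
w_{t+1,i} =
\begin{cases}
0, & |z_{t,i}| \le \lambda_1,\\
-\eta_{t,i}\,\big(z_{t,i} - \operatorname{sgn}(z_{t,i})\,\lambda_1\big), & \text{otherwise},
\end{cases}
\qquad \eta_{t,i} = \frac{\alpha}{\beta + \sqrt{n_{t,i}}},\quad n_{t,i} = \sum_{s=1}^{t} g_{s,i}^2
```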

In the formula, the first term estimates the contribution to the loss function, the second term keeps w (i.e. the model) from changing too much in each iteration, and the third term is the L1 regularizer (which produces sparse solutions).

0x02 example code

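We use the official Alink sample code. It is roughly divided into the following parts: feature engineering, splitting the stream into training and evaluation data, training an initial batch model, FTRL online training, FTRL online prediction, and evaluation.

The comments in the code below mark each of these parts.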

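public class FTRLExample {

    public static void main(String[] args) throws Exception {
        // ... imports, data schema, column-name arrays (numericalColNames, selectedColNames,
        // categoryColNames), labelColName, vecColName, numHashFeatures and the
        // trainBatchData source are elided here
        // Build feature engineering pipeline
        // setup feature engineering pipeline
        Pipeline featurePipeline = new Pipeline()
                .add(new StandardScaler()        // standard scaling
                        .setSelectedCols(numericalColNames))
                .add(new FeatureHasher()         // feature hashing
                        .setSelectedCols(selectedColNames)
                        .setCategoricalCols(categoryColNames)
                        .setOutputCol(vecColName)
                        .setNumFeatures(numHashFeatures));
        // fit feature pipeline model
        PipelineModel featurePipelineModel = featurePipeline.fit(trainBatchData);
        // prepare stream train data (file path and schema setters elided)
        CsvSourceStreamOp data = new CsvSourceStreamOp();
        // The original training data and prediction data are obtained by splitting the stream source in real time
        // split stream to train and eval data
        SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5).linkFrom(data);
        // A logistic regression model is trained as the initial FTRL model; it is needed for the cold start of the system
        // train initial batch model
        LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp()
                .setVectorCol(vecColName)
                .setLabelCol(labelColName);
        BatchOperator<?> initModel = featurePipelineModel.transform(trainBatchData).link(lr);

        // FTRL online training based on the initial model
        // ftrl train
        FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)
                .setVectorCol(vecColName)
                .setLabelCol(labelColName)
                // hyperparameter setters (alpha, beta, L1, L2, time interval, ...) elided
                .linkFrom(featurePipelineModel.transform(splitter));
        // On top of the FTRL online model, connect the prediction data for prediction
        // ftrl predict
        FtrlPredictStreamOp predictResult = new FtrlPredictStreamOp(initModel)
                .setVectorCol(vecColName)
                .setPredictionCol("pred")
                .setReservedCols(new String[]{labelColName})
                .setPredictionDetailCol("details")
                .linkFrom(model, featurePipelineModel.transform(splitter.getSideOutput(0)));
        // Evaluate the prediction result stream
        // ftrl eval
        predictResult
                .link(new EvalBinaryClassStreamOp()
                        .setLabelCol(labelColName)
                        .setPredictionDetailCol("details"))
                .link(new JsonValueStreamOp()
                        .setSelectedCol("Data")
                        .setReservedCols(new String[]{"Statistics"})
                        .setOutputCols(new String[]{"Accuracy", "AUC", "ConfusionMatrix"})
                        .setJsonPath(new String[]{"$.Accuracy", "$.AUC", "$.ConfusionMatrix"}))
                .print();
        StreamOperator.execute();
    }
}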

0x03 problem

It is better to use questions to guide the analysis. Here are some questions we can easily think of.

  • Are there preset models in both the training phase and the prediction phase to cope with the "cold start"?
  • How are the training phase and the prediction phase related?
  • How is the trained model transferred to the prediction stage?
  • When outputting a model, what should be done if the model is too large?
  • What mechanism updates the online training model? Is the update time-driven?
  • Can the prediction phase still predict while a model is being loaded? Is there a mechanism that guarantees prediction during this period?
  • Which stages of training and prediction use parallel processing?
  • How are high-dimensional vectors handled? Are they split?

We will explore these issues one by one.

0x04 overall logic

Online training is implemented in the FtrlTrainStreamOp class, and its linkFrom function implements the basic logic.

The main logic is:

  • 1) Load the initial model into dataBridge: dataBridge = DirectReader.collect(model);
  • 2) Get the relevant parameters, e.g. vectorSize (30000 by default) and whether there is an intercept term (hasInterceptItem);
  • 3) Get the segmentation information: splitInfo = getSplitInfo(featureSize, hasInterceptItem, parallelism); it is used in the next step.
  • 4) Split the high-dimensional vectors. Hashing the initial data produces high-dimensional vectors, which must be cut here: initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize, vectorTrainIdx, featureIdx, labelIdx));
  • 5) Build an IterativeStream.ConnectedIterativeStreams iteration, which builds (or connects) two data streams: the feedback stream and the training stream;
  • 6) Build iterativeBody from the iteration; it consists of two parts, CalcTask and ReduceTask;
    • CalcTask has two parts: flatMap1 performs the distributed computation of the prediction part needed by each FTRL iteration, and flatMap2 updates the FTRL parameters;
    • ReduceTask has two functions: merge these partial prediction results, and, when the conditions are met, merge the model and output it to downstream operators;
  • 7) result = iterativeBody.filter(...); the decision is essentially based on a time interval (so it can be regarded as time-driven). Data whose "time has not expired & vector is meaningful" is sent back to the feedback stream to continue iterating, returning to step 6) and entering flatMap2;
  • 8) output = iterativeBody.filter(...); data that meets the condition (time expired) leaves the iteration, and the algorithm calls WriteModel to convert LinearModelData into multiple rows and forward them to the downstream operator (the online prediction stage). In other words, the model is periodically pushed to the online prediction stage.
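The time-driven output in steps 7) and 8) can be illustrated with a small hypothetical Java class. Its names and structure are invented for illustration and do not reflect Alink's CalcTask or ReduceTask: updates keep being applied while the interval has not expired, and once it has, a snapshot of the model is emitted downstream. Timestamps are passed in explicitly so the behavior is deterministic.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of time-driven model output: keep iterating until the
// interval expires, then emit a snapshot of the model for the prediction side.
final class TimedModelEmitter {
    private final long intervalMillis;
    private final double[] weights;
    private final List<double[]> emitted = new ArrayList<>();
    private long lastEmitMillis;

    TimedModelEmitter(int dim, long intervalMillis, long startMillis) {
        this.weights = new double[dim];
        this.intervalMillis = intervalMillis;
        this.lastEmitMillis = startMillis;
    }

    // Applies one parameter update; returns true when a model snapshot was emitted.
    boolean onUpdate(int index, double delta, long nowMillis) {
        weights[index] += delta;
        if (nowMillis - lastEmitMillis >= intervalMillis) {
            emitted.add(weights.clone());  // "time expired": the model leaves the loop
            lastEmitMillis = nowMillis;
            return true;
        }
        return false;                       // "time not expired": keep iterating
    }

    List<double[]> emittedModels() {
        return emitted;
    }

    public static void main(String[] args) {
        TimedModelEmitter e = new TimedModelEmitter(2, 1000, 0);
        e.onUpdate(0, 0.5, 200);
        e.onUpdate(1, -0.25, 999);
        e.onUpdate(0, 0.5, 1000);  // interval reached: snapshot {1.0, -0.25}
        System.out.println(java.util.Arrays.toString(e.emittedModels().get(0)));
    }
}
```

Emitting a copy rather than the live array means later updates cannot alter a model that has already been sent downstream.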

The code summary is:

public FtrlTrainStreamOp linkFrom(StreamOperator<?>... inputs) {
    // 1) Load the initial model: dataBridge = DirectReader.collect(model);
    // 2) Reading of parameters (vectorSize, hasInterceptItem, ...) is elided here

    // 3) Get segmentation information
    final int[] splitInfo = getSplitInfo(featureSize, hasInterceptItem, parallelism);

    DataStream<Row> initData = inputs[0].getDataStream();

    // 4) Segment high-dimensional vectors.
    // Tuple5<SampleId, taskId, numSubVec, SubVec, label>
    DataStream<Tuple5<Long, Integer, Integer, Vector, Object>> input
        = initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,
        vectorTrainIdx, featureIdx, labelIdx))
        .partitionCustom(new CustomBlockPartitioner(), 1);

    // train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>
    // feedback format = Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>
    // 5) Build an IterativeStream.ConnectedIterativeStreams iteration with two inputs:
    //    the training stream and the feedback stream
    IterativeStream.ConnectedIterativeStreams<Tuple5<Long, Integer, Integer, Vector, Object>,
        Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
        iteration = input.iterate(Long.MAX_VALUE)
            .withFeedbackType(TypeInformation
                .of(new TypeHint<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {}));

    // 6) Build iterativeBody from the iteration; it consists of CalcTask and ReduceTask
    DataStream iterativeBody = iteration.flatMap(
        new CalcTask(dataBridge, splitInfo, getParams()))
        .flatMap(new ReduceTask(parallelism, splitInfo))
        .partitionCustom(new CustomBlockPartitioner(), 1);

    // 7) result: records whose time has not expired and whose vector is meaningful
    //    are fed back into the iteration (CalcTask.flatMap2)
    DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
        result = iterativeBody.filter(
        new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
            @Override
            public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> t3)
                throws Exception {
                // if t3.f0 > 0 && t3.f2 > 0 then feedback
                return (t3.f0 > 0 && t3.f2 > 0);
            }
        });

    // 8) output: records whose time has expired leave the iteration; WriteModel converts
    //    LinearModelData into rows for the downstream operator (online prediction stage)
    DataStream<Row> output = iterativeBody.filter(
        new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
            @Override
            public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value)
                throws Exception {
                /* if value.f0 small than 0, then output */
                return value.f0 < 0;
            }
        }).flatMap(new WriteModel(labelType, getVectorCol(), featureCols, hasInterceptItem));

    // Mark result as the end of the iteration; it is fed back to the iteration as the second input
    iteration.closeWith(result);

    TableSchema schema = new LinearModelDataConverter(labelType).getModelSchema();
    this.setOutput(output, schema.getFieldNames(), schema.getFieldTypes());
    return this;
}

To make it easier to read, here is a flow chart (splitting into training and test data sets is omitted):

Forgive me for drawing it in plain text; I hate finding a good article only to discover that its pictures are gone.

              Initial model training phase
┌─────────────────┐                       ┌─────────────────┐
│ trainBatchData  │                       │ trainStreamData │
└─────────────────┘                       └─────────────────┘
         │                                         │
┌─────────────────┐                                │
│ featurePipeline │                                │
└─────────────────┘                                │
         │                                         │
┌───────────────────────────┐                      │
│ logistic regression model │                      │
└───────────────────────────┘                      │
         │                                         │
              Online training phase                │
         │                                         │
┌─────────────────┐                       ┌─────────────────┐
│ dataBridge      │ load initial model    │ featurePipeline │
└─────────────────┘                       └─────────────────┘
         │                                         │
┌─────────────────┐                       ┌───────────────────────────┐
│ getSplitInfo    │ segmentation info     │ inputs[0].getDataStream() │
└─────────────────┘                       └───────────────────────────┘
         │ splitInfo                               │ feature vectors
┌─────────────────┐                                │
│ SplitVector     │ <──────────────────────────────┘
└─────────────────┘
         │ DataStream<Tuple5<SampleId, taskId, numSubVec, SubVec, label>>
┌────────────────────────────┐
│ iteration <Tuple5, Tuple7> │ inputs: train data Tuple5<>, feedback data Tuple7<>
└────────────────────────────┘
         │
         ├────────────────────────────┐
         │                            │
┌───────────────────┐       ┌───────────────────┐
│ CalcTask.flatMap1 │       │ CalcTask.flatMap2 │ <── feedback Tuple7<> from the result filter
└───────────────────┘       └───────────────────┘
  input Tuple5<>:           input Tuple7<> (feedback):
  distributed computation   update parameters and emit
  of the FTRL predict part  them once the time expires
         │                            │
         │<───────────────────────────┘
         │ both flatMaps output to ReduceTask
┌────────────────────┐
│ ReduceTask.flatMap │ 1. time expired & collection complete: merge and output the model (value.f0 < 0)
└────────────────────┘ 2. not expired: merge the partial predictions computed by each CalcTask
         │
         ├──────────────────────────────────┐
         │                                  │
┌───────────────────┐             ┌───────────────────┐
│ output = filter   │             │ result = filter   │ t3.f0 > 0 && t3.f2 > 0
└───────────────────┘             └───────────────────┘
  value.f0 < 0 ("time expired"):  "time not expired & vector meaningful":
  leaves the iteration            fed back to CalcTask.flatMap2
         │
┌───────────────────┐
│ WriteModel        │ periodically outputs the model
└───────────────────┘
         │
              Online prediction phase
         │                         ┌─────────────────┐
         │                         │ testStreamData  │
         │                         └─────────────────┘
         │                                  │
         │                         ┌─────────────────┐
         │                         │ featurePipeline │
         │                         └─────────────────┘
         │                                  │
┌──────────────┐                            │
│ FTRL Predict │ <──────────────────────────┘
└──────────────┘


Posted by Aretai on Mon, 30 May 2022 19:28:46 +0530