Kalman Filter Modifier for Neural Networks in Non-stationary Environments
Honglin Li
CVSSP, University of Surrey [email protected]
Frieder Ganz
Adobe Systems Inc [email protected]
Shirin Enshaeifar
CVSSP, University of Surrey [email protected]
Payam Barnaghi
CVSSP, University of Surrey [email protected]
Abstract
Learning in a non-stationary environment is an inevitable problem when applying machine learning algorithms to real-world settings. Learning new tasks without forgetting previous knowledge is a challenging issue in machine learning. We propose a Kalman Filter based modifier to maintain the performance of Neural Network models in non-stationary environments. The results show that our proposed model preserves the key information and adapts better to the changes. In our experiments, the accuracy of the proposed model decreases by 0.4%, while the accuracy of the conventional model decreases by 90% in the drift environment.
In many real-world scenarios, the underlying process of a data stream is non-stationary. The performance of neural networks may decrease when the source distribution of the input changes. Even worse, these changes may lead to the catastrophic forgetting problem. An important issue for neural networks is the ability to preserve previously learned information and to adapt faster to changes in a dynamic environment. We aim to solve the problem of forgetting previously learned information under data drifts. According to Bayesian Decision Theory [1], a classification problem can be defined as maximising the posterior probability P(y|X), where y represents the classes of the data X. Drifts can be viewed as two types [2, 3]: real drift refers to changes in P(y|X), where the data distribution remains the same but the class of the data changes; virtual drift refers to changes in P(X), where the class of the data remains the same but the distribution of the data changes. Real drift and virtual drift need replacement learning and supplemental learning, respectively [4]. In the real world, we cannot choose the learning environment, so a robust model is needed to adapt to these drifts. We propose a model with a Kalman Filter modifier to adjust the learning parameters of neural network models. Our experiments show that our proposed solution adapts better and faster than conventional neural network models in drift environments.

The performance of an online learning model decreases if the training data or its distribution changes. From the gradient point of view, the location of the optimum will change. For example, if we consider a single parameter as shown in Figure (1), the parameters of the pre-trained model have a minimum loss. However, the loss will be relatively high when the distribution of the data changes. In this case, the parameters of the online learning model will move to the new optimum value.
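The two drift types can be illustrated with a toy stream. The following sketch uses our own names and thresholds (not from the paper): a fixed threshold classifier fit on the original task keeps its accuracy under virtual drift (only P(X) shifts) but collapses under real drift (P(y|X) flips).

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(mean0, mean1, flip_labels=False, n=2000):
    """Draw a labelled stream: class 0 ~ N(mean0, 0.5), class 1 ~ N(mean1, 0.5)."""
    y = rng.integers(0, 2, n)
    x = rng.normal(np.where(y == 0, mean0, mean1), 0.5)
    if flip_labels:
        y = 1 - y          # real drift: P(y|X) changes, P(X) unchanged
    return x, y

x_a, y_a = sample(-1.0, 1.0)                    # original task
x_v, y_v = sample(-3.0, 3.0)                    # virtual drift: P(X) shifts
x_r, y_r = sample(-1.0, 1.0, flip_labels=True)  # real drift: labels swapped

# A fixed decision rule "trained" on the original task: predict class 1 if x > 0
acc = lambda x, y: np.mean((x > 0).astype(int) == y)
```

Here the rule still separates the virtually drifted data well, while under real drift nearly every prediction is wrong, which is why real drift calls for replacement learning rather than supplemental learning.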
Preprint. Work in progress.

However, this change could be a significant problem in training a consistent model when the data periodically goes back to a previous state or distribution [5]. We address this problem by finding an optimal estimation between the data and the changes caused by the drifts. For this purpose we train a Kalman Filter, which acts as an optimal estimator [6]. This method can infer parameters from uncertain and inaccurate observations.

Figure 1: (a) The parameter of the model will be around the red point when the model has converged. The red line indicates the gradient at the red point. (b) The training task changes from A to B. The red point is no longer an optimum value. The black line shows that the parameter has a tendency to move to the black point. (c) The red line (Pre_G) is the gradient of the model on task A; the black line (Cur_G) is the gradient of the model on task B. Our proposed method estimates a new value (New_G) based on these two gradients.

m_k = A m_{k-1} + B u_k + w_k,    z_k = H m_k + v_k    (1)

In our approach, we use a mini-batch method to train the model. In Equation (1), m_k is the k-th state of the model, and the initial state m_0 represents the pre-trained model. Each state means the model is training on the k-th batch of the new data. z_k refers to the output of the model. From the gradient descent point of view, in Equation (1), A = I, B is the learning rate, u_k is the gradient of the model on D_k, which refers to the k-th batch of data, and H = I. Because we omit the process noise, w_k and v_k are 0. This is how neural networks perform gradient descent from a linear state perspective. However, the model trained this way is not an optimal or accurate model. Based on the gradient algorithm, we can obtain the gradient of the pre-trained model, which indicates the measurement error of the current model.
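Equation (1) with A = I, H = I and zero noise is just a gradient descent step written as a linear state transition. A minimal sketch, where the sign convention B = -lr * I is our assumption (descending requires subtracting the gradient u_k):

```python
import numpy as np

def sgd_step_as_state_model(m_prev, u_k, lr=0.1):
    """One gradient descent step in the form of Equation (1):
    m_k = A m_{k-1} + B u_k, with A = I, B = -lr * I, w_k = 0."""
    A = np.eye(m_prev.size)
    B = -lr * np.eye(m_prev.size)
    m_k = A @ m_prev + B @ u_k
    z_k = m_k  # measurement z_k = H m_k with H = I and v_k = 0
    return m_k, z_k

m = np.array([1.0, 2.0])       # current parameters m_{k-1}
grad = np.array([0.5, -1.0])   # gradient u_k on the current batch
m_next, z = sgd_step_as_state_model(m, grad)
assert np.allclose(m_next, m - 0.1 * grad)  # identical to plain SGD
```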
Using a Kalman Filter, we have the following prediction step:

x̂_{k|k-1} = x_{k-1},    P_{k|k-1} = P_{k-1}    (2)

We omit all the process noise and assume that the system is static (the dynamics matrix A is the identity). In Equation (2), the first formula is the state prediction and the second is the state error prediction, where x̂_{k|k-1} is the predicted model parameters at the k-th state, which we assume is stable with no additional information given, and P_k is the state error, which is assumed to be stable as well.

K_k = P_{k|k-1} / (P_{k|k-1} + R)
x̂_k = x̂_{k|k-1} + K_k (z_k − x̂_{k|k-1})
P_k = (I − K_k) P_{k|k-1}    (3)

Equation (3) is the update process, where R is the measurement noise (the gradient of m_k on D_k), K_k is the Kalman Gain at the k-th state, and x̂_k is the estimated parameter vector of the model. Algorithm (1) summarises the proposed steps.
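One predict/update cycle of Equations (2)–(3), applied element-wise per parameter, can be sketched as follows (variable names are ours, not from the paper):

```python
import numpy as np

def kalman_modifier_step(x_prev, P_prev, z, R):
    """One predict/update cycle of Equations (2)-(3).

    x_prev : previous estimate of the parameters (x_{k-1})
    P_prev : previous state error (P_{k-1})
    z      : measurement = parameters after training on the new batch
    R      : measurement noise = gradient of m_k on D_k
    """
    # Predict (Equation 2): static model, no process noise
    x_pred, P_pred = x_prev, P_prev
    # Update (Equation 3)
    K = P_pred / (P_pred + R)            # Kalman gain
    x_new = x_pred + K * (z - x_pred)    # blend old and measured parameters
    P_new = (1.0 - K) * P_pred
    return x_new, P_new

x, P = kalman_modifier_step(np.zeros(3), np.ones(3), np.ones(3), np.ones(3))
```

Note the intended behaviour: when R is large (the model still has large gradients on the new data, so the measurement is unreliable), K shrinks and the estimate stays close to the previous parameters, which is what preserves earlier knowledge.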
Result: model m̂_k
Initialisation: m_0 = pre-trained model; P_0 = gradients(m_0 on the pre-trained dataset)
while training on a new task do
    Produce model m_k: train m_{k-1} on D_k
    Calculate R: R = gradients(m_k on D_k)
    Kalman Gain: K_k = P_{k-1} / (P_{k-1} + R)
    Produce model m̂_k: m̂_k = KalmanFilter([m_{k-1}, P_{k-1}], [m_k, R])
    Update: P_k = (I − K_k) P_{k-1}
end
Algorithm 1: The Kalman Filter based algorithm
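The loop of Algorithm 1 can be sketched end to end on a toy linear model in place of a neural network. The model, data, and hyper-parameters below are our assumptions for illustration, not from the paper:

```python
import numpy as np

rng = np.random.default_rng(1)

def gradients(w, X, y):
    """Mean-squared-error gradient of a linear model y ≈ X @ w
    (a stand-in for the neural network's gradients)."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

def train_with_kalman_modifier(w, P, batches, lr=0.1):
    """Sketch of Algorithm 1: each batch is trained with SGD, then the
    result is blended with the previous model via the Kalman gain."""
    for X, y in batches:
        m_k = w - lr * gradients(w, X, y)        # train m_{k-1} on D_k
        R = np.abs(gradients(m_k, X, y)) + 1e-8  # measurement noise R
        K = P / (P + R)                          # per-parameter Kalman gain
        w = w + K * (m_k - w)                    # estimated model \hat{m}_k
        P = (1.0 - K) * P                        # update state error P_k
    return w, P

w_true = np.array([2.0, -1.0])
X = rng.normal(size=(200, 2))
y = X @ w_true
w0 = np.zeros(2)
# P_0 initialised from the gradient at the starting model, as in Algorithm 1
P0 = np.abs(gradients(w0, X, y)) + 1e-8
batches = [(X[i:i + 20], y[i:i + 20]) for i in range(0, 200, 20)]
w, P = train_with_kalman_modifier(w0, P0, batches)
```

Because P shrinks after every update, the gain K decreases over time, so later batches perturb the parameters less and less; this is the damping that resists drift-induced forgetting.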
Experiment Results
We train a fully-connected multi-layer neural network [7] on three datasets sequentially. Within each task, the model is trained for a fixed number of epochs, and the training data will no longer be available to the model afterwards. We constructed a set of classification tasks based on the MNIST dataset [8] (Figure 2). The data in the first task is the original MNIST dataset. In the second task we permute all the pixels of the images; this requires a completely different solution. The final task is related to the real drift problem: we change all the labels by adding 1 to the value of the label (e.g. if the image is a 3, the label is changed to 4). The results show that, no matter what the training dataset is, the Kalman Filter allows the online learning model to respond more efficiently to the changes and to maintain an overall better performance compared to a conventional model without any modifiers (see Figure 3).
Figure 2: (a) The original (pre-trained) MNIST dataset. (b) Permuted MNIST dataset. (c) The training images are the same, but the labels are increased by one; in this example, the label of the image becomes 4.
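The three tasks above can be generated with a short sketch. The array shapes and the modulo wrap for label 9 are our assumptions (the paper only states that 1 is added to each label):

```python
import numpy as np

rng = np.random.default_rng(0)

def make_tasks(images, labels, n_classes=10):
    """Build the three sequential tasks: original, permuted pixels
    (virtual drift), and shifted labels (real drift).
    `images` is an (N, 784) array of flattened digits."""
    perm = rng.permutation(images.shape[1])   # one fixed pixel permutation
    return [
        (images, labels),                     # task 1: original data
        (images[:, perm], labels),            # task 2: virtual drift, P(X) changes
        (images, (labels + 1) % n_classes),   # task 3: real drift, P(y|X) changes
    ]

# Dummy stand-in data with MNIST's flattened shape
imgs = np.arange(5 * 784, dtype=float).reshape(5, 784)
labels = np.array([0, 3, 9, 5, 1])
task1, task2, task3 = make_tasks(imgs, labels)
```

The same permutation is applied to every image, so task 2 is exactly as learnable as task 1 but requires entirely different input weights.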
Figure 3: (a) The accuracy of the model on the pre-trained validation dataset. (b) The accuracy of the model on the pre-trained test dataset after being trained on new tasks. The orange lines show the accuracy of the proposed model; the blue lines show the accuracy of the conventional neural network. The red dashed lines indicate the change points between the tasks; each dashed line marks a drift.
We present a novel online learning method that responds to data drifts by using a Kalman Filter modifier. This addresses the forgetting problem for neural networks in non-stationary environments. Our proposed method does not require any changes in the architecture of the neural network. We use the Kalman Filter to adjust the learning parameters in changing environments. The Kalman Filter modifier takes the weights as the measurement value and the gradient as the measurement error. We demonstrate our approach using both virtual and real drifts and show that the proposed model remembers previously learned information to adjust the online learning parameters. The method is characterised by an intrinsic recursive algorithm, so it does not need to access previously seen data. Our evaluation results show that our approach performs better in responding to changes and has a lower learning error compared with a conventional model. Our future work will focus on improving the Kalman Filter and comparing it with advanced catastrophic forgetting methodologies in non-stationary environments.

Acknowledgments
This work is partially supported by the EU H2020 IoTCrawler project under contract number: 779852.
References

[1] R. O. Duda, P. E. Hart, D. G. Stork et al., "Pattern classification," J. of Computational Intelligence and Applications, vol. 1, pp. 335–339, 2001.
[2] A. Tsymbal, "The problem of concept drift: definitions and related work," Computer Science Department, Trinity College Dublin, vol. 106, no. 2, 2004.
[3] I. Žliobaitė, "Learning under concept drift: an overview," arXiv preprint arXiv:1010.4784, 2010.
[4] R. Elwell and R. Polikar, "Incremental learning of concept drift in nonstationary environments," IEEE Transactions on Neural Networks, vol. 22, no. 10, pp. 1517–1531, 2011.
[5] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska et al., "Overcoming catastrophic forgetting in neural networks," Proceedings of the National Academy of Sciences, p. 201611835, 2017.
[6] M. B. Rhudy, R. A. Salguero, and K. Holappa, "A Kalman filtering tutorial for undergraduate students," International Journal of Computer Science & Engineering Survey (IJCSES), vol. 8, pp. 1–18, 2017.
[7] J. Schmidhuber, "Deep learning in neural networks: An overview," Neural Networks, vol. 61, pp. 85–117, 2015.
[8] Y. LeCun, C. Cortes, and C. Burges, "MNIST handwritten digit database."