A Context-aware Capsule Network for Multi-label Classification
Sameera Ramasinghe, C.D. Athuraliya, and Salman H. Khan
ConscientAI Labs, Colombo, Sri Lanka
Australian National University, Canberra, Australia
Abstract.
The recently proposed Capsule Network (CapsNet) is a brain-inspired architecture that brings a new paradigm to deep learning by modelling input-domain variations through vector-based representations. Despite being a seminal contribution, CapsNet does not explicitly model structured relationships between the detected entities, nor among the capsule features of related inputs. Motivated by the working of the cortical network in the human visual system, we seek to resolve these limitations by proposing several intuitive modifications to the CapsNet architecture. We introduce (1) a novel routing weight initialization technique, (2) an improved CapsNet design that exploits semantic relationships between the primary capsule activations using a densely connected Conditional Random Field, and (3) a Cholesky transformation based correlation module that learns a general priority scheme. Our proposed design allows CapsNet to scale better to more complex problems, such as the multi-label classification task, where semantically related categories co-exist with various interdependencies. We present theoretical bases for our extensions and demonstrate significant improvements on the ADE20K scene dataset.
Nearly two decades after their inception, convolutional neural networks (CNNs) [1] have become the norm for computer vision tasks. Vision tasks that widely use CNNs include object recognition [2,3], object detection [4,5] and semantic segmentation [6,7]. Despite their popularity and high effectiveness in most vision tasks, previous works have pointed out several limitations of CNNs in vision applications. One major limitation is the notable trade-off between preserving spatial information and gaining transformation invariance through pooling operations. Furthermore, CNNs handle rotational invariance only marginally.

To overcome the aforementioned limitations of CNNs, the recently introduced Capsule Networks (CapsNets) [8] propose a novel deep architecture for feature abstraction that preserves the underlying spatial information. This architecture is motivated by human brain function, favours equivariance over invariance, and demonstrates comparable performance on digit classification with the MNIST dataset [9]. These early results manifest a new direction for future deep architectures. However, to our knowledge, the CapsNet architecture has not been used for larger and more complex datasets, specifically for multi-label classification tasks, where the goal is to tag an input image with multiple object categories. This is because the original CapsNet does not incorporate the contextual information necessary for complex tasks such as multi-label classification. In this work, we evaluate the original CapsNet architecture on a large image dataset with 150 object classes that appear in complex real-world scenes. We then propose a new context-aware CapsNet architecture that makes informed predictions by exploiting semantic relationships of object classes as well as underlying correlations of low-level capsules.
Our model is inspired by the working of the human brain, where contextual and prior information is effectively modeled [10].

To enable faster training on large datasets, we first propose a novel weight initialization scheme based on parameters trained with back-propagation. This update allows the initial routing weights to capture low-level feature distributions, and it improves the convergence rate and accuracy compared to the equal routing weight initialization of the original CapsNet.
Second, we argue that the corresponding elements of primary capsule predictions are interrelated, since primary capsule predictions encapsulate the attributes of object classes. In simple terms, the presence of object attributes (such as position, rotation and texture) in one capsule's output depends on similar attributes detected by neighbouring capsules. This property was not utilized in the original CapsNet architecture. To characterize it, we introduce an end-to-end trainable Conditional Random Field (CRF) that encourages network predictions to be more context specific.
Third, the original CapsNet captures the priority between primary and decision capsules independently for each data point. We argue that there also exists a general priority scheme between decision and primary capsules that is distributed across the dataset. Therefore, we propose a correlation module that captures the overall priority of primary capsule predictions throughout the dataset and thereby encapsulates broader context.

We apply the proposed architecture to multi-label classification on a large scene dataset, ADE20K [11], and report significant improvements over the original CapsNet architecture.
Hinton et al. [12] first proposed capsules as a new module in deep neural networks, in the form of transforming auto-encoders. Capsules were suggested as an alternative to the widely adopted subsampling layers of CNNs and as a way to encapsulate more precise spatial relationships. Sabour et al. [8] recently proposed a complete neural network architecture for capsules with dynamic routing and a reconstruction loss. They demonstrated state-of-the-art performance on the MNIST dataset [9]. They also outperformed existing CNN architectures on a new dataset, MultiMNIST [8], created by overlaying one digit on top of another digit from a different class. More recently, Hinton et al. [13] proposed an updated capsule architecture with a logistic unit and a new iterative routing procedure between capsule layers based on the Expectation-Maximization (EM) algorithm [14]. This new capsule architecture significantly outperformed baseline CNN models on the smallNORB dataset [15] and was reported to be less vulnerable to white-box adversarial attacks. Xi et al. [16] extended the initial CapsNet work by applying it to the CIFAR-10 classification task. However, CapsNet has not been used before for complex structured prediction tasks, and our work is a key step in this direction.
A decision capsule is considered to be a complete representation of an object class. That means each of its scalar elements describes a certain attribute of an object class, such as rotation or position. These attributes may not be semantically meaningful, but an object can be completely reconstructed using the elements of the corresponding capsule. Corresponding elements of different decision capsules represent similar attributes of different objects. For example, the i-th scalar element of the j-th decision capsule may represent the rotation of a chair, while the i-th scalar element of the (j+1)-th decision capsule may describe the rotation of a desk.

The predictions by primary capsules for decision capsules encapsulate the attributes of an object class. Therefore, the corresponding elements of outputs from primary capsules are conditioned upon each other. For example, there may be a hidden condition such that, if the primary capsule is in state A, a chair cannot be rotated in direction α when a spatially nearby desk is rotated in direction β. To exploit this behavior, we feed primary capsule predictions to an end-to-end trainable CRF module that learns the inter-dependencies among attributes.

Here, the CRF module is used as a structured prediction mechanism for each primary capsule to conditionally alter its predictions. Thus the CRF is able to capture semantic relationships across object classes. Moreover, we introduce a correlation module which can prioritize predictions by primary capsules and effectively predict decision capsules. The overall architecture is illustrated in Figure 1. We begin with the description of routing weight initialization, and then explain the densely connected CRF and the correlation module in subsequent sections.

In the original CapsNet, primary capsules can be interpreted as a set of Z stacked feature maps. Each primary capsule element can be considered as part of a low-level feature.
Following this assumption, we rearrange the primary capsules as an N × N × D grid, where N × N × D is the total number of primary capsules. Each item in the grid is a capsule with I dimensions; hence, D = Z/I. Instead of initializing the routing weights equally, we make the initial routing weights trainable parameters and learn them with backpropagation. This makes the initial routing weights depend on the low-level feature distribution, resulting in faster convergence.
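The routing loop itself is not restated in this section. As a minimal sketch, assuming the routing-by-agreement procedure of Sabour et al. [8], the only change needed for a trainable initialization is that the routing logits start from a given tensor instead of zeros. The function names (`route`, `squash`, `softmax`) and the argument `initial_logits` are hypothetical; only the forward pass is shown, and how the paper derives these logits from low-level feature statistics is not modelled here. In a framework such as PyTorch or JAX, `initial_logits` would be a learnable parameter updated by backpropagation.

```python
import math
from typing import List

Vector = List[float]


def squash(s: Vector) -> Vector:
    """CapsNet squashing non-linearity: keeps the direction, maps the length into [0, 1)."""
    norm_sq = sum(x * x for x in s)
    if norm_sq == 0.0:
        return [0.0] * len(s)
    scale = norm_sq / (1.0 + norm_sq) / math.sqrt(norm_sq)
    return [scale * x for x in s]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route(u_hat: List[List[Vector]], initial_logits: List[Vector], iterations: int = 3):
    """Routing-by-agreement starting from the given routing logits.

    u_hat[k][j]          : prediction vector of primary capsule k for decision capsule j
    initial_logits[k][j] : starting logits b_kj. All zeros reproduces the equal
                           initialization of the original CapsNet; the proposed scheme
                           would learn these values (hypothetical parameter here).
    Returns the decision capsule vectors v_j and the coupling coefficients used for them.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    num_primary, num_decision = len(u_hat), len(u_hat[0])
    dim = len(u_hat[0][0])
    b = [row[:] for row in initial_logits]  # copy: routing updates must not overwrite the parameters
    for _ in range(iterations):
        # Coupling coefficients: softmax over decision capsules for each primary capsule.
        c = [softmax(b[k]) for k in range(num_primary)]
        # Weighted sum of predictions, then squash, for each decision capsule.
        v = [
            squash([sum(c[k][j] * u_hat[k][j][d] for k in range(num_primary)) for d in range(dim)])
            for j in range(num_decision)
        ]
        # Agreement update: dot product between each prediction and the current output.
        for k in range(num_primary):
            for j in range(num_decision):
                b[k][j] += sum(u_hat[k][j][d] * v[j][d] for d in range(dim))
    return v, c
```

With all-zero `initial_logits`, the first iteration weighs every route equally (each coupling coefficient is 1/J). A learned, non-uniform starting point lets the first iteration already favour some routes, which is one plausible way to read the faster convergence reported for the trainable initialization.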
Fig. 1. Proposed CapsNet architecture. (Diagram components: Input, Conv 1, Primary Capsules, Primary Capsule Predictions, CRF Module, Correlation Module, Routing Parameters used to initialize the routing weights, Decision Capsules.)
To this end, we first define a statistical value per capsule to represent its element distribution. Let K and J be the number of primary and decision capsules respectively, and let $C = \{c_1, c_2, \ldots, c_K\}$ be the set of primary capsules. We then map the capsules to a set $S = \{s_1, s_2, \ldots, s_K\}$, where $s_k = \mu_k / \max(\sigma_k, \epsilon)$, $\forall\, 0 < k \le K$. […]

[…] CRF as a stack of CNN layers
1: $H_{kj}(i) = \frac{1}{X_{i,k}} \exp\big(-E_u(P_{k,j}(i))\big)$, $\forall\, i, j, k$ ▷ Initialization
2: for itr = 0 to MaxItr do
3:   $\bar{H}_{kj}(i) = \sum_{j'} E_p\big(H_{kj}(i), H_{kj'}(i)\big)$ ▷ Calculation of pair-wise potentials
4:   $\tilde{H}_{kj}(i) = H_{kj}(i) - \bar{H}_{kj}(i)$ ▷ Addition of pair-wise potentials to unary potentials
5:   $H_{kj}(i) = \frac{1}{X_{i,k}} e^{\tilde{H}_{kj}(i)}$ ▷ Normalization
6: end for

The first line is the initialization. Here, $X_{i,k} = \sum_{j=1}^{J} e^{P_{k,j}(i)}$, where J is the number of decision capsules. Since $E_u(P_{k,j}(i))$ is the cost of the i-th element of the prediction, we can take $E_u(P_{k,j}(i)) = -P_{k,j}(i)$. Line 1 is then equivalent to applying the softmax function over the i-th elements of the predictions made by the k-th primary capsule for all J decision capsules. Line 3 computes the pair-wise potentials. Instead of deriving the pair-wise potential function manually, we use back-propagation to find the optimal mapping, which is both effective and efficient. Since all corresponding element pairs have to be taken into account, we apply a fully connected layer on top of the predictions to learn this pair-wise potential function. Since we are minimizing $Z(x)_{k,i}$ for each i and k independently, these layers are not connected across i or k, which reduces the computational complexity significantly. Line 4 adds the pair-wise potentials to the unary potentials. Line 5 is equivalent to applying the softmax function over the outputs, i.e., the normalizer is recomputed from $\tilde{H}_{kj}(i)$.
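A minimal sketch of the mean-field loop in lines 1 to 6, for one primary capsule k and one element index i at a time (the layers are not connected across i or k). It assumes $E_u = -P$ as stated, and represents the learned fully connected pair-wise layer by a hypothetical J × J weight matrix `pairwise`, held fixed here instead of trained. Line 4 is implemented exactly as printed ($\tilde{H} = H - \bar{H}$), and line 5 is a softmax over the decision capsules. The names `crf_mean_field` and `crf_layer` are illustrative, not the authors' code.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def crf_mean_field(p: List[float], pairwise: List[List[float]], max_itr: int = 5) -> List[float]:
    """Mean-field updates for a single (k, i) slice.

    p[j]            : i-th element of primary capsule k's prediction for decision capsule j
    pairwise[j][j2] : hypothetical weights of the fully connected pair-wise layer (J x J)
    """
    num_decision = len(p)
    h = softmax(p)  # line 1: unary term with E_u = -P, normalized over j
    for _ in range(max_itr):
        # Line 3: pair-wise potentials as a linear map over the J decision capsules.
        h_bar = [sum(pairwise[j][j2] * h[j2] for j2 in range(num_decision)) for j in range(num_decision)]
        # Line 4, as printed in the algorithm.
        h_tilde = [h[j] - h_bar[j] for j in range(num_decision)]
        # Line 5: normalization (softmax over j).
        h = softmax(h_tilde)
    return h


def crf_layer(P, pairwise, max_itr=5):
    """Apply the slice update independently to every primary capsule k and element i.

    P[k][j][i] holds primary capsule predictions; the output has the same layout.
    """
    K, J, I = len(P), len(P[0]), len(P[0][0])
    H = [[[0.0] * I for _ in range(J)] for _ in range(K)]
    for k in range(K):
        for i in range(I):
            h = crf_mean_field([P[k][j][i] for j in range(J)], pairwise, max_itr)
            for j in range(J):
                H[k][j][i] = h[j]
    return H
```

With `max_itr=0` the layer reduces to the per-slice softmax of line 1; each further iteration mixes in the learned pair-wise term. Sharing one `pairwise` matrix across all slices is a simplification of this sketch; separate matrices per (k, i) are equally consistent with the text.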
In the CapsNet architecture, each primary capsule makes a unique prediction for each decision capsule. Since primary capsules are essentially a set of low-level features, this can be viewed as each low-level feature estimating the state of the output class. Moreover, the priority of each low-level feature depends on the output class. For example, a circle detector may perform better in predicting the state of a wheel, while a horizontal edge detector may perform better in predicting the state of a bridge.

The original routing technique tries to capture these varying priorities of primary capsules with respect to decision capsules through a weighted sum of predictions. The routing weights are adjusted in the next iteration according to the similarity between the primary capsule predictions and the decision capsule of the current iteration, where the similarity is measured by a dot product. With this method, the network learns the priorities independently for each data point. However, we argue that there is also a general priority scheme, distributed across the whole dataset, that can be learned during training. Therefore, we propose a novel correlation-based approach to discover these priorities and estimate the final prediction.

Unlike the CRF module, our objective here is to find the correlation between the attribute distributions of the corresponding predictions of primary capsules and a decision capsule, instead of finding the dependency between each single corresponding attribute of the predictions. Given a set of predictions for a specific decision capsule, the goal of the correlation module is to find the decision capsule elements by exploiting the priorities of each primary capsule. The correlation coefficients between a decision capsule and the primary capsule predictions are learned throughout training. Furthermore, these correlation coefficients should depend on the low-level feature distribution and should also be trainable.
To this end, we use a property of the Cholesky transformation [18] and derive a generic function to achieve this task.

Let Q and R be two distributions (assumed uncorrelated, with zero mean and unit variance). The Cholesky transformation

$\begin{bmatrix} \bar{Q} \\ \bar{R} \end{bmatrix} = \begin{bmatrix} 1 & 0 \\ \rho_1 & \sqrt{1-\rho_1^2} \end{bmatrix} \begin{bmatrix} Q \\ R \end{bmatrix}$   (3)

$\bar{Q} = Q, \quad \bar{R} = \rho_1 Q + \sqrt{1-\rho_1^2}\, R$   (4)

produces two distributions $\bar{Q}$, $\bar{R}$ that are correlated by a factor of $\rho_1$. Likewise,

$\begin{bmatrix} \bar{\bar{Q}} \\ \bar{\bar{R}} \end{bmatrix} = \begin{bmatrix} 1 & 0 \\ \rho_2 & \sqrt{1-\rho_2^2} \end{bmatrix} \begin{bmatrix} R \\ Q \end{bmatrix}$   (5)

$\bar{\bar{Q}} = R, \quad \bar{\bar{R}} = \rho_2 R + \sqrt{1-\rho_2^2}\, Q$   (6)

produces two distributions $\bar{\bar{Q}}$, $\bar{\bar{R}}$ that are correlated by a factor of $\rho_2$. Therefore, if we choose

$\rho_1 = \sqrt{1-\rho_2^2}$   (7)

we get $T = \bar{R} = \bar{\bar{R}} = \rho_1 Q + \rho_2 R$, where T and R are correlated by $\rho_2$, and T and Q are correlated by $\rho_1$. Using this property and considering two component distributions $D_1 = P_{k,j}$ and $D_2 = P_{k',j}$, where $P_{k,j}$ is the component distribution of the k-th primary capsule prediction for the j-th decision capsule, we obtain a new distribution $\hat{D}$ satisfying $\rho_{\hat{D},D_2} = \rho$ and $\rho_{\hat{D},D_1} = \alpha\rho$. Here, $\rho_{x_1,x_2}$ denotes the correlation between the two distributions $x_1$ and $x_2$. Using Equation 7,

$\alpha\rho = \sqrt{1-\rho^2}, \quad \rho = \frac{1}{\sqrt{1+\alpha^2}}$   (8)

$\hat{D} = \frac{\alpha}{\sqrt{1+\alpha^2}} D_1 + \frac{1}{\sqrt{1+\alpha^2}} D_2$   (9)

Using Equation 9, we define a recursive function $f_\rho$ to obtain a correlated element distribution:

$f_\rho(P_{1,j} \mid P_{2,j}, \ldots, P_{k,j}, \ldots, P_{K,j}) = \frac{\alpha_K}{\sqrt{1+\alpha_K^2}}\, f_\rho(P_{1,j} \mid P_{2,j}, \ldots, P_{k,j}, \ldots, P_{K-1,j}) + \frac{P_{K,j}}{\sqrt{1+\alpha_K^2}}, \quad \forall\, 0 < k \le K,\ 0 < j \le J$   (10)

where $f_\rho(P_{1,j} \mid P_{2,j}) = \frac{\alpha_2}{\sqrt{1+\alpha_2^2}} P_{1,j} + \frac{1}{\sqrt{1+\alpha_2^2}} P_{2,j}$. Using this derivation, we obtain the j-th decision capsule $C_j = f_\rho(P_{1,j} \mid P_{2,j}, \ldots, P_{k,j}, \ldots, P_{K,j})$. Here, α is required to be trainable and to depend on low-level feature distributions. Since the above operation is differentiable, the first criterion is fulfilled. To make α depend on low-level features, we use the following method. Consider an N × N low-level feature map.
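The combination in Equations 9 and 10 can be sketched as a left fold over the K predictions for one decision capsule. This is a minimal illustration, assuming the component distributions behave like uncorrelated, zero-mean, unit-variance variables (the condition under which the stated correlations hold); `combine_pair`, `f_rho` and `pearson` are hypothetical names, and the α values are plain numbers here rather than outputs of learned kernels.

```python
import math
import random
from typing import List, Sequence


def combine_pair(d1: Sequence[float], d2: Sequence[float], alpha: float) -> List[float]:
    """Equation 9, element-wise: alpha/sqrt(1+alpha^2) * D1 + 1/sqrt(1+alpha^2) * D2."""
    norm = math.sqrt(1.0 + alpha * alpha)
    return [(alpha * a + b) / norm for a, b in zip(d1, d2)]


def f_rho(preds: Sequence[Sequence[float]], alphas: Sequence[float]) -> List[float]:
    """Equation 10 unrolled: fold P_1, ..., P_K for one decision capsule from left to right.

    alphas holds alpha_2, ..., alpha_K, i.e. K - 1 values per decision capsule.
    """
    if len(preds) < 2 or len(alphas) != len(preds) - 1:
        raise ValueError("need K >= 2 predictions and exactly K - 1 alpha values")
    out = list(preds[0])
    for p_k, alpha_k in zip(preds[1:], alphas):
        out = combine_pair(out, p_k, alpha_k)  # f(P_1..P_k) from f(P_1..P_{k-1}) and P_k
    return out


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Sample Pearson correlation coefficient."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / math.sqrt(vx * vy)


if __name__ == "__main__":
    rng = random.Random(0)
    n, alpha = 100_000, 2.0
    d1 = [rng.gauss(0.0, 1.0) for _ in range(n)]
    d2 = [rng.gauss(0.0, 1.0) for _ in range(n)]
    d_hat = combine_pair(d1, d2, alpha)
    rho = 1.0 / math.sqrt(1.0 + alpha ** 2)
    print(f"corr(D_hat, D2) = {pearson(d_hat, d2):.3f}, expected rho = {rho:.3f}")
    print(f"corr(D_hat, D1) = {pearson(d_hat, d1):.3f}, expected alpha*rho = {alpha * rho:.3f}")
```

The printed sample correlations should be close to ρ and αρ from Equation 8. Because the squared coefficients in Equation 9 sum to one, each fold step keeps unit variance under these assumptions, so the correlation of the output with $P_{1,j}$ is the product of the factors $\alpha_k/\sqrt{1+\alpha_k^2}$, while the most recently added prediction keeps correlation $1/\sqrt{1+\alpha_K^2}$.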
Since we need J(K − 1) trainable parameters as per Equation 10, we convolve this N × N feature map with a set of J(K − 1) kernels, each of size N × N. This outputs J(K − 1) scalar values, which are used as the α parameters.

We conduct experiments to demonstrate the effectiveness of each of the improvements: the new initialization scheme for routing weights, the CRF module and the correlation module. We use mean average precision (mAP) with a precision threshold of 0.5 as the evaluation metric throughout the experiments. We evaluate the proposed architecture on the ADE20K dataset, given its complex scenes and rich multi-label annotations. ADE20K provides over 20,000 training and testing images annotated with 150 semantic object categories.

The goal of replacing the equal initialization of routing weights with trainable weights is faster convergence. To test this, we train the proposed architecture with and without the trainable initial routing scheme and measure the validation mAP. The results are illustrated in Figure 2. As shown in Figure 2, the validation mAP stabilizes around the 15th epoch for the CapsNet without the proposed routing weight initialization. In contrast, the CapsNet with the proposed routing weight initialization stabilizes around the 9th epoch. This indicates that the proposed method achieves faster convergence than equal initial routing weights.
Fig. 2. Evaluation of convergence gain by trainable initial routing weights. (Plot: validation mAP with and without the proposed routing weight initialization.)

Table 1. Comparison of the proposed architecture with the baseline (RW: routing weight initialization; CORR: correlation module).
Method | mAP
Original CapsNet | 42.38
Ours (RW + CRF) | 52.50
Ours (RW + CRF + CORR) | 56.71

We compare the original CapsNet architecture with the proposed architecture by measuring the mean average precision over the validation images. Table 1 shows the comparison results. We gain a significant 14.33 mAP over the total 150 object classes compared to the original architecture. Furthermore, in this experiment we demonstrate the added performance gains of the CRF and correlation modules, and show that each module provides complementary improvements. We gain an improvement of 10.12 mAP by adding the CRF module and a further 4.21 mAP by adding the correlation module on top of the CRF module. All architectures are trained for 20 epochs.
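The text reports mAP but does not define average precision or say what the 0.5 threshold is applied to. As a minimal sketch of one common definition (non-interpolated AP computed from the score ranking of each class, then averaged over the classes that have at least one positive image), the following ignores the threshold entirely and is not necessarily the protocol used for Table 1. The function names are hypothetical.

```python
from typing import Sequence


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Non-interpolated AP for one class: mean of precision@rank over the ranks of positives."""
    ranked = sorted(zip(scores, labels), key=lambda t: t[0], reverse=True)
    hits, precisions = 0, []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(score_matrix, label_matrix) -> float:
    """score_matrix[n][c] and label_matrix[n][c] for N images and C classes.

    Classes with no positive image are skipped, since their AP is undefined.
    """
    num_classes = len(score_matrix[0])
    aps = []
    for c in range(num_classes):
        labels = [row[c] for row in label_matrix]
        if any(labels):
            aps.append(average_precision([row[c] for row in score_matrix], labels))
    if not aps:
        raise ValueError("no class has a positive example")
    return sum(aps) / len(aps)
```

For a multi-label dataset such as ADE20K, each image contributes a score and a binary label to every class column, so an image containing several object categories counts as a positive in several per-class rankings.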
In this work, we attempt to overcome several limitations of CapsNet by introducing an improved architecture inspired by contextual modeling in the visual cortex [10]. Our objective is twofold: to effectively capture complex interactions between primary capsules, and to leverage data-wide correlations between representations of similar inputs. To this end, we introduced three novel ideas. First, we proposed a new routing weight initialization that can be trained using back-propagation. This replaces the existing equal initial routing weights with a more intuitive and efficient technique. Second, we introduced a CRF-based method that exploits conditional attributes of primary capsule predictions to capture the context of neighbouring objects. Third, we proposed a correlation module that learns a dataset-wide priority scheme instead of capturing the priority separately for each data point. As demonstrated through our experiments, these improvements in the CapsNet model design contribute to a substantial relative mAP improvement of over 33% in multi-label classification on a challenging dataset.

References
1. LeCun, Y., Bottou, L., Bengio, Y., Haffner, P.: Gradient-based learning applied to document recognition. In: Proceedings of the IEEE. (1998) 2278–2324
2. Krizhevsky, A., Sutskever, I., Hinton, G.E.: ImageNet classification with deep convolutional neural networks. In: Advances in Neural Information Processing Systems. (2012) 1097–1105
3. Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., Rabinovich, A.: Going deeper with convolutions. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. (2015) 1–9
4. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: Towards real-time object detection with region proposal networks. In: Advances in Neural Information Processing Systems 28. (2015) 91–99
5. Redmon, J., Divvala, S.K., Girshick, R.B., Farhadi, A.: You only look once: Unified, real-time object detection. arXiv preprint arXiv:1506.02640 (2015)
6. Long, J., Shelhamer, E., Darrell, T.: Fully convolutional networks for semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. (2015) 3431–3440
7. Badrinarayanan, V., Kendall, A., Cipolla, R.: SegNet: A deep convolutional encoder-decoder architecture for image segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence (2017)
8. Sabour, S., Frosst, N., Hinton, G.E.: Dynamic routing between capsules. In: Advances in Neural Information Processing Systems. (2017) 3859–3869
9. LeCun, Y., Cortes, C., Burges, C.J.: The MNIST database of handwritten digits. (1998)
10. Bar, M.: Visual objects in context. Nature Reviews Neuroscience 5