Exploiting Shared Representations for Personalized Federated Learning
Liam Collins, Hamed Hassani, Aryan Mokhtari, Sanjay Shakkottai
Abstract
Deep neural networks have shown the ability to extract universal feature representations from data such as images and text that have been useful for a variety of learning tasks. However, the fruits of representation learning have yet to be fully realized in federated settings. Although data in federated settings is often non-i.i.d. across clients, the success of centralized deep learning suggests that data often shares a global feature representation, while the statistical heterogeneity across clients or tasks is concentrated in the labels. Based on this intuition, we propose a novel federated learning framework and algorithm for learning a shared data representation across clients and unique local heads for each client. Our algorithm harnesses the distributed computational power across clients to perform many local updates with respect to the low-dimensional local parameters for every update of the representation. We prove that this method obtains linear convergence to the ground-truth representation with near-optimal sample complexity in a linear setting, demonstrating that it can efficiently reduce the problem dimension for each client. Further, we provide extensive experimental results demonstrating the improvement of our method over alternative personalized federated learning approaches in heterogeneous settings.

Liam Collins, Aryan Mokhtari, and Sanjay Shakkottai: Department of Electrical and Computer Engineering, The University of Texas at Austin, Austin, TX, USA. Hamed Hassani: Department of Electrical and Systems Engineering, University of Pennsylvania, Philadelphia, PA, USA.

arXiv [cs.LG], February

Introduction
Many of the most heralded successes of modern machine learning have come in centralized settings, wherein a single model is trained on a large amount of centrally stored data. The growing number of data-gathering devices, however, calls for a distributed architecture to train models. Federated learning aims to address this issue by providing a platform in which a group of clients collaborate to learn effective models for each client by leveraging the local computational power, memory, and data of all clients [McMahan et al., 2017]. The task of coordinating between the clients is fulfilled by a central server that combines the models received from the clients at each round and broadcasts the updated information to them. Importantly, the server and clients are restricted to methods that satisfy communication and privacy constraints, preventing them from directly applying centralized techniques.

One of the most important challenges in federated learning, however, is the issue of data heterogeneity, where the underlying data distributions of client tasks could be substantially different from each other. In such settings, if the server and clients learn a single shared model (e.g., by minimizing average loss), the resulting model could perform poorly for many of the clients in the network (and also not generalize well across diverse data [Jiang et al., 2019]). In fact, for some clients, it might be better to simply use their own local data (even if it is small) to train a local model; see Figure 1. Finally, the (federated) trained model may not generalize well to unseen clients that have not participated in the training process. These issues raise the following question:

“How can we exploit the data and computational power of all clients in data-heterogeneous settings to learn a personalized model for each client?”

We address this question by taking advantage of the common representation among clients. Specifically, we view the data-heterogeneous federated learning problem as $n$ parallel learning tasks that possibly share some common structure, and our goal is to learn and exploit this common representation to improve the quality of each client's model. Indeed, this is in line with our understanding from centralized learning, where training multiple tasks simultaneously by leveraging a common (low-dimensional) representation has succeeded in popular machine learning tasks (e.g., image classification, next-word prediction) [Bengio et al., 2013, LeCun et al., 2015].

Main Contributions.
We introduce a novel federated learning framework and an associated algorithm for data-heterogeneous settings. Our main contributions are as follows.

(i) FedRep Algorithm. Federated Representation Learning (FedRep) leverages the full quantity of data stored across clients to learn a global low-dimensional representation using gradient-based updates. Further, it enables each client to compute a personalized, low-dimensional classifier, which we term the client's local head, that accounts for the unique labeling of each client's local data.

(ii) Convergence Rate. We show that FedRep converges to the optimal representation at an exponentially fast rate with near-optimal sample complexity in the case that each client aims to solve a linear regression problem with a two-layer linear neural network. Our analysis further implies that we only need $O(\kappa^4 k^2 (k\log(rn) + d/(rn)))$ samples per client, where $n$ is the number of clients, $d$ is the dimension of the data, $k$ is the representation dimension, $r$ is the participation rate, and $\kappa$ is the condition number of the ground-truth client-representation matrix.

(iii) Empirical Results. Through a combination of synthetic and real datasets (CIFAR10, CIFAR100, FEMNIST, Sent140) we show the benefits of FedRep in: (a) leveraging many local updates, (b) robustness to different levels of heterogeneity, and (c) generalization to new clients. We consider several important baselines including FedAvg [McMahan et al., 2017], Fed-MTL [Smith et al., 2017], LG-FedAvg [Liang et al., 2020], and Per-FedAvg [Fallah et al., 2020]. Our experiments indicate that FedRep outperforms these baselines in heterogeneous settings that share a global representation.

Figure 1 (plot "Local MSE for d = 20, k = 2, n = 100": average MSE versus number of training samples per user, from 0.25d to d, for Local Only, FedAvg, and FedRep): Local-only training suffers in small-training-data regimes, whereas training a single global model (FedAvg) cannot overcome client heterogeneity even when the number of training samples is large. FedRep exploits a common representation of the clients to achieve small error in all cases.
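To get a feel for the scaling claimed in (ii), the sketch below keeps only the terms that depend on $n$ and $d$: it drops $\kappa$, $k$, and all constants, and sets the participation rate $r = 1$. These are simplifying assumptions made purely for illustration, and the function names are hypothetical. It compares the FedRep per-client term $\log(rn) + d/(rn)$ with the $\Theta(d)$ samples a client would need when learning alone.

```python
import math

def fedrep_per_client_scale(d: int, n: int, r: float = 1.0) -> float:
    """Illustrative FedRep per-client term log(rn) + d/(rn) (kappa, k, constants dropped)."""
    participating = r * n
    return math.log(participating) + d / participating

def local_only_scale(d: int) -> float:
    """A client learning alone needs on the order of d samples."""
    return float(d)

if __name__ == "__main__":
    d = 1000
    for n in (10, 100, 1000, 10000):
        ratio = fedrep_per_client_scale(d, n) / local_only_scale(d)
        print(f"n={n:>6}: FedRep/local-only ratio = {ratio:.4f}")
```

Because $\log(n) + d/n$ is minimized at $n = d$, the printed ratio for $d = 1000$ falls until $n = 1000$ and then slowly rises as the $\log(rn)$ term takes over. The advantage over local training persists as long as $n \ll e^{\Theta(d)}$.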
Benefits of FedRep.
Next, we list the benefits of FedRep over standard federated learning (which learns a single model).

(I) More local updates. By reducing the problem dimension, each client can make many local updates at each communication round, which is beneficial in learning its own individual head. This is unlike standard federated learning, where multiple local updates in a heterogeneous setting move each client away from the best averaged representation and thus hurt performance.

(II) Gains of cooperation. Denote by $d$ the data dimension and by $n$ the number of clients. From our sample complexity bounds, it follows that with FedRep, the sample complexity per client scales as $\Theta(\log(n) + d/n)$. On the other hand, local learning (without any collaboration) has a sample complexity that scales as $\Theta(d)$. Thus, if $1 \ll n \ll e^{\Theta(d)}$ (see Section 4.2 for details), we expect benefits of collaboration through federation. When $d$ is large (as is typical in practice), $e^{\Theta(d)}$ is exponentially larger, and federation helps each client. To the best of our knowledge, this is the first sample-complexity-based result for heterogeneous federated learning that demonstrates the benefit of cooperation.

(III) Generalization to new clients. For a new client, since a ready-made representation is available, the client only needs to learn a head with a low-dimensional representation of dimension $k$. Thus, its sample complexity scales only as $\Theta(k\log(1/\epsilon))$ to have no more than $\epsilon$ error in accuracy.

Related Work.
A variety of recent works have studied personalization in federated learning using, for example, local fine-tuning [Wang et al., 2019, Yu et al., 2020], meta-learning [Chen et al., 2018, Fallah et al., 2020, Jiang et al., 2019, Khodak et al., 2019], additive mixtures of local and global models [Deng et al., 2020, Hanzely and Richtárik, 2020, Mansour et al., 2020], and multi-task learning [Smith et al., 2017]. In all of these methods, each client's subproblem is still full-dimensional: there is no notion of learning a dimensionality-reduced set of local parameters. More recently, Liang et al. [2020] also proposed a representation learning method for federated learning, but their method attempts to learn many local representations and a single global head, as opposed to a single global representation and many local heads. Earlier, Arivazhagan et al. [2019] presented an algorithm to learn local heads and a global network body, but their local procedure jointly updates the head and body (using the same number of updates), and they did not provide any theoretical justification for their proposed method. Meanwhile, another line of work has studied federated learning in heterogeneous settings [Haddadpour et al., 2020, Karimireddy et al., 2020, Pathak and Wainwright, 2020, Reddi et al., 2020, Wang et al., 2020], and the optimization-based insights from these works may be used to supplement our formulation and algorithm.
The generic form of federated learning with $n$ clients is
$$\min_{(q_1,\ldots,q_n)\in\mathcal{Q}_n} \frac{1}{n}\sum_{i=1}^n f_i(q_i), \tag{1}$$
where $f_i$ and $q_i$ are the error function and learning model for the $i$-th client, respectively, and $\mathcal{Q}_n$ is the space of feasible sets of $n$ models. We consider a supervised setting in which the data for the $i$-th client is generated by a distribution $(x_i, y_i)\sim\mathcal{D}_i$. The learning model $q_i:\mathbb{R}^d\to\mathcal{Y}$ maps inputs $x_i\in\mathbb{R}^d$ to predicted labels $q_i(x_i)\in\mathcal{Y}$, which we would like to resemble the true labels $y_i$. The error $f_i$ is in the form of an expected risk over $\mathcal{D}_i$, namely $f_i(q_i) := \mathbb{E}_{(x_i,y_i)\sim\mathcal{D}_i}[\ell(q_i(x_i), y_i)]$, where $\ell:\mathcal{Y}\times\mathcal{Y}\to\mathbb{R}$ is a loss function that penalizes the distance of $q_i(x_i)$ from $y_i$.

In order to minimize $f_i$, the $i$-th client accesses a dataset of $M_i$ labelled samples $\{(x_i^j, y_i^j)\}_{j=1}^{M_i}$ from $\mathcal{D}_i$ for training. Federated learning addresses settings in which the $M_i$'s are typically small relative to the problem dimension while the number of clients $n$ is large. Thus, clients may not be able to obtain solutions $q_i$ with small expected risk by training completely locally on only their $M_i$ local samples. Instead, federated learning enables the clients to cooperate, by exchanging messages with a central server, in order to learn models using the cumulative data of all the clients.

Standard approaches to federated learning aim at learning a single shared model $q = q_1 = \cdots = q_n$ that performs well on average across the clients [Li et al., 2018, McMahan et al., 2017].
In this way, the clients aim to solve a special version of Problem (1), which is to minimize $(1/n)\sum_i f_i(q)$ over the choice of the shared model $q$. However, this approach may yield a solution that performs poorly in heterogeneous settings where the data distributions $\mathcal{D}_i$ vary across the clients. Indeed, in the presence of data heterogeneity, the error functions $f_i$ will have different forms and their minimizers are not the same. Hence, learning a shared model $q$ may not provide a good solution to Problem (1). This necessitates the search for more personalized solutions $\{q_i\}$ that can be learned in a federated manner using the clients' data.

Figure 2: Federated representation learning structure where clients and the server aim at learning a global representation $\phi$ together, while each client $i$ learns its unique head $h_i$ locally.

Learning a Common Representation.
We are motivated by insights from centralized machine learning suggesting that heterogeneous data distributed across tasks may share a common representation despite having different labels [Bengio et al., 2013, LeCun et al., 2015]; e.g., shared features across many types of images, or across word-prediction tasks. Using this common (low-dimensional) representation, the labels for each client can be simply learned using a linear classifier or a shallow neural network.

Formally, we consider a setting consisting of a global representation $\phi:\mathbb{R}^d\to\mathbb{R}^k$, which maps data points to a lower-dimensional space of size $k$, and client-specific heads $h_i:\mathbb{R}^k\to\mathcal{Y}$. The model for the $i$-th client is the composition of the client's local parameters and the representation: $q_i(x) = (h_i\circ\phi)(x)$. Critically, $k \ll d$, meaning that the number of parameters that must be learned locally by each client is small. Thus, we can assume that any client's optimal classifier for any fixed representation is easy to compute, which motivates the following rewritten global objective:
$$\min_{\phi\in\Phi} \frac{1}{n}\sum_{i=1}^n \min_{h_i\in\mathcal{H}} f_i(h_i\circ\phi), \tag{2}$$
where $\Phi$ is the class of feasible representations and $\mathcal{H}$ is the class of feasible heads. In our proposed scheme, clients cooperate to learn the global model using all clients' data, while they use their local information to learn their personalized heads. We discuss this in detail in Section 3.

2.1 Comparison with Standard Federated Learning

To formally demonstrate the advantage of our formulation over the standard (single-model) federated learning formulation in heterogeneous settings with a shared representation, we study a linear representation setting with quadratic loss.
As we will see below, standard federated learning cannot recover the underlying representation in the face of heterogeneity, while our formulation does indeed recover it.

Consider a setting in which the functions $f_i$ are quadratic losses, the representation $\phi$ is a projection onto a $k$-dimensional subspace of $\mathbb{R}^d$ given by a matrix $B\in\mathbb{R}^{d\times k}$, and the $i$-th client's local head $h_i$ is a vector $w_i\in\mathbb{R}^k$. In this setting, we model the local data of clients $\{\mathcal{D}_i\}_i$ such that $y_i = w_i^{*\top}B^{*\top}x_i$ for some ground-truth representation $B^*\in\mathbb{R}^{d\times k}$ and local heads $w_i^*\in\mathbb{R}^k$. This setting will be described in detail in Section 4. In particular, one can show that the expected error over the data distribution $\mathcal{D}_i$ has the form $f_i(w_i\circ B) := \|Bw_i - B^*w_i^*\|_2^2$. Consequently, Problem (2) becomes
$$\min_{B\in\mathbb{R}^{d\times k},\, w_1,\ldots,w_n\in\mathbb{R}^k} \frac{1}{n}\sum_{i=1}^n \|Bw_i - B^*w_i^*\|_2^2. \tag{3}$$
In contrast, standard federated learning methods, which aim to learn a shared model $(B, w)$ for all the clients, solve
$$\min_{B\in\mathbb{R}^{d\times k},\, w\in\mathbb{R}^k} \frac{1}{n}\sum_{i=1}^n \|Bw - B^*w_i^*\|_2^2. \tag{4}$$
Let $(\hat B, \{\hat w_i\}_i)$ denote a global minimizer of (3). We thus have $\hat B\hat w_i = B^*w_i^*$ for all $i\in[n]$. Also, it is not hard to see that $(B^\diamond, w^\diamond)$ is a global minimizer of (4) if and only if $B^\diamond w^\diamond = B^*\big(\frac{1}{n}\sum_{i=1}^n w_i^*\big)$. Thus, our formulation finds an exact solution with zero global error, whereas standard federated learning has global error $\frac{1}{n}\sum_{i=1}^n \big\|\frac{1}{n}B^*\sum_{i'=1}^n (w_{i'}^* - w_i^*)\big\|_2^2$, which grows with the heterogeneity of the $w_i^*$. Moreover, since solving our formulation provides $n$ matrix equations, we can fully recover the column space of $B^*$ as long as the $w_i^*$'s span $\mathbb{R}^k$. In contrast, solving (4) yields only one matrix equation, so there is no hope of recovering the column space of $B^*$ for any $k > 1$.

FedRep Algorithm
FedRep solves Problem (2) by distributing the computation across clients. The server and clients aim to learn the global representation $\phi$ together, while the $i$-th client aims to learn its unique local head, denoted by $h_i$, locally (see Figure 2). To do so, FedRep alternates between client updates and a server update on each communication round.

Client Update. On each round, a constant fraction $r\in(0,1]$ of the clients is selected to execute a client update. In the client update, client $i$ makes $\tau$ local gradient-based updates to solve for its optimal head given the current global representation $\phi^t$ communicated by the server. Namely, starting from $h_i^{t,0} = h_i^{t-1,\tau}$, for $s = 1,\ldots,\tau$, client $i$ updates its head as follows:
$$h_i^{t,s} = \mathrm{GRD}\big(f_i(h_i^{t,s-1}\circ\phi^t),\, h_i^{t,s-1},\, \alpha\big),$$
where $\mathrm{GRD}(f, h, \alpha)$ is generic notation for an update of the variable $h$ using a gradient of function $f$ with respect to $h$ and the step size $\alpha$. For example, $\mathrm{GRD}(f_i(h_i^{t,s-1}\circ\phi^t), h_i^{t,s-1}, \alpha)$ can be a step of gradient descent, stochastic gradient descent (SGD), SGD with momentum, etc. The key is that client $i$ makes many such local updates, i.e., $\tau$ is large, to find the optimal head based on its local data, given the most recent representation $\phi^t$ received from the server.

Algorithm 1 FedRep
Parameters: participation rate $r$; step sizes $\alpha, \eta$; number of local updates $\tau$; number of communication rounds $T$; initial $\phi, h_1, \ldots, h_n$.
for $t = 1, 2, \ldots, T$ do
  Server receives a batch of clients $\mathcal{I}^t$ of size $rn$
  Server sends current representation $\phi^t$ to these clients
  for each client $i$ in $\mathcal{I}^t$ do
    Client $i$ initializes $h_i^{t,0} \leftarrow h_i^{t-1,\tau}$
    Client $i$ makes $\tau$ updates to its head:
    for $s = 1$ to $\tau$ do
      $h_i^{t,s} \leftarrow \mathrm{GRD}(f_i(h_i^{t,s-1}\circ\phi^t), h_i^{t,s-1}, \alpha)$
    end for
    Client $i$ locally updates the representation as $\phi_i^{t+1} \leftarrow \mathrm{GRD}(f_i(h_i^{t,\tau}\circ\phi^t), \phi^t, \eta)$
    Client $i$ sends updated representation $\phi_i^{t+1}$ to the server
  end for
  for each client $i$ not in $\mathcal{I}^t$ do
    Set $h_i^{t,\tau} \leftarrow h_i^{t-1,\tau}$
  end for
  Server computes the new representation as $\phi^{t+1} = \frac{1}{rn}\sum_{i\in\mathcal{I}^t}\phi_i^{t+1}$
end for

Server Update.
Once the local updates with respect to the head $h_i$ finish, the client participates in the server update by taking one local gradient-based update with respect to the current representation, i.e., computing
$$\phi_i^{t+1} \leftarrow \mathrm{GRD}\big(f_i(h_i^{t,\tau}\circ\phi^t),\, \phi^t,\, \eta\big).$$
It then sends $\phi_i^{t+1}$ to the server, which averages the local updates to compute the next representation $\phi^{t+1}$. The entire procedure is outlined in Algorithm 1.

In this section, we analyze an instance of Problem (2) with quadratic loss functions and linear models, as discussed in Section 2.1. Here, each client's problem is to solve a linear regression with a two-layer linear neural network. In particular, each client $i$ attempts to find a shared global projection onto a low-dimensional subspace $B\in\mathbb{R}^{d\times k}$ and a unique regressor $w_i\in\mathbb{R}^k$ that together accurately map its samples $x_i\in\mathbb{R}^d$ to labels $y_i\in\mathbb{R}$. The matrix $B$ corresponds to the representation $\phi$, and $w_i$ corresponds to the local head $h_i$ for the $i$-th client. We thus have $h_i\circ\phi(x_i) = w_i^\top B^\top x_i$. Hence, the loss function for client $i$ is given by
$$f_i(w_i\circ B) := \mathbb{E}_{(x_i,y_i)\sim\mathcal{D}_i}\big[(y_i - w_i^\top B^\top x_i)^2\big], \tag{5}$$
meaning that the global objective is
$$\min_{B\in\mathbb{R}^{d\times k},\, W\in\mathbb{R}^{n\times k}} F(B, W) := \frac{1}{2n}\sum_{i=1}^n \mathbb{E}_{(x_i,y_i)}\big[(y_i - w_i^\top B^\top x_i)^2\big], \tag{6}$$
where $W = [w_1, \ldots, w_n]^\top \in\mathbb{R}^{n\times k}$ is the concatenation of client-specific heads (one head per row). To evaluate the ability of FedRep to learn an accurate representation, we model the local datasets $\{\mathcal{D}_i\}_i$ such that, for $i = 1,\ldots,n$,
$$y_i = w_i^{*\top}B^{*\top}x_i$$
for some ground-truth representation $B^*\in\mathbb{R}^{d\times k}$ and local heads $w_i^*\in\mathbb{R}^k$, i.e., a standard regression setting. In other words, all of the clients' optimal solutions live in the same $k$-dimensional subspace of $\mathbb{R}^d$, where $k$ is assumed to be small.
For simplicity, we consider the case in which $y_i$ contains no noise, as adding noise does not significantly change our analysis. Moreover, we make the following standard assumption on the samples $x_i$.

Assumption 1. (Sub-gaussian design) The samples $x_i\in\mathbb{R}^d$ are i.i.d. with mean $0$, covariance $I_d$, and are $I_d$-sub-gaussian, i.e., $\mathbb{E}[e^{v^\top x_i}] \le e^{\|v\|_2^2/2}$ for all $v\in\mathbb{R}^d$.

We next discuss how FedRep tries to recover the optimal representation in this setting.
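Under the noiseless model above, the identity covariance in Assumption 1 is exactly what turns the population loss (5) into the quadratic form quoted in Section 2.1. A short derivation, using only $\mathbb{E}[x_ix_i^\top] = I_d$:

$$
\begin{aligned}
f_i(w_i\circ B) &= \mathbb{E}\big[(w_i^{*\top}B^{*\top}x_i - w_i^\top B^\top x_i)^2\big]
= \mathbb{E}\big[\big((B^*w_i^* - Bw_i)^\top x_i\big)^2\big] \\
&= (B^*w_i^* - Bw_i)^\top\,\mathbb{E}[x_ix_i^\top]\,(B^*w_i^* - Bw_i)
= \|Bw_i - B^*w_i^*\|_2^2 .
\end{aligned}
$$

Averaging over clients, the population objective (6) is therefore one half of objective (3), so the linear analysis studies exactly the formulation that was compared against standard federated learning in Section 2.1.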
Client Update.
As in Algorithm 1, $rn$ clients are selected on round $t$ to update their current local heads $w_i^t$ and the global representation $B^t$. Each selected client $i$ samples a fresh batch $\{(x_i^{t,j}, y_i^{t,j})\}_{j=1}^m$ of $m$ samples according to its local data distribution $\mathcal{D}_i$ to use for updating both its head and representation on each round $t$ that it is selected. That is, within the round, client $i$ considers the batch loss
$$\hat f_i^t(w_i^t\circ B^t) := \frac{1}{2m}\sum_{j=1}^m \big(y_i^{t,j} - w_i^{t\top}B^{t\top}x_i^{t,j}\big)^2. \tag{7}$$
Since $\hat f_i^t$ is strongly convex with respect to $w_i^t$, the client can find an update for a local head that is $\epsilon$-close to the global minimizer of (7) after at most $O(\log(1/\epsilon))$ local gradient updates. Alternatively, since the function is also quadratic, the client can solve for the optimal $w$ directly in only $O(mk^2 + k^3)$ operations. Thus, to simplify the analysis we assume each selected client obtains $w_i^{t+1} = \arg\min_w \hat f_i^t(w\circ B^t)$ during each round of local updates.

Server Update.
After updating its head, client $i$ updates the global representation with one step of gradient descent using the same $m$ samples and sends the update to the server, as outlined in Algorithm 2. Then, the server computes the new representation by averaging over the received representations.

Algorithm 2 FedRep for linear regression
Input: step size $\eta$; number of rounds $T$; participation rate $r$; initial $B^0$.
for $t = 1, 2, \ldots, T$ do
  Server receives a subset $\mathcal{I}^t$ of clients of size $rn$
  Server sends current representation $B^t$ to these clients
  for $i \in \mathcal{I}^t$ do
    Client $i$ samples a fresh batch of $m$ samples
    Client $i$ updates its head: $w_i^{t+1} \leftarrow \arg\min_w \hat f_i^t(w\circ B^t)$
    Client $i$ updates the representation: $B_i^{t+1} \leftarrow B^t - \eta\nabla_B \hat f_i^t(w_i^{t+1}\circ B^t)$
    Client $i$ sends $B_i^{t+1}$ to the server
  end for
  Server averages the updates: $B^{t+1} \leftarrow \frac{1}{rn}\sum_{i\in\mathcal{I}^t} B_i^{t+1}$
end for

As mentioned earlier, in FedRep, each client $i$ performs an alternating minimization-descent method to solve its nonconvex objective in (7). This means the global loss over all clients at round $t$ is given by
$$\frac{1}{n}\sum_{i=1}^n \hat f_i^t(w_i^t\circ B^t) := \frac{1}{2mn}\sum_{i=1}^n\sum_{j=1}^m \big(y_i^{t,j} - w_i^{t\top}B^{t\top}x_i^{t,j}\big)^2. \tag{8}$$
This objective has many global minima, including all pairs of matrices $(W^*Q^{-1}, B^*Q^\top)$ where $Q\in\mathbb{R}^{k\times k}$ is invertible, eliminating the possibility of exactly recovering the ground-truth factors $(W^*, B^*)$. Instead, the ultimate goal of the server is to recover the ground-truth representation, i.e., the column space of $B^*$. To evaluate how closely the column space is recovered, we define the distance between subspaces as follows.

Definition 1.
The principal angle distance between the column spaces of $B_1, B_2\in\mathbb{R}^{d\times k}$ is given by
$$\mathrm{dist}(B_1, B_2) := \|\hat B_{1,\perp}^\top\hat B_2\|_2, \tag{9}$$
where $\hat B_{1,\perp}$ and $\hat B_2$ are orthonormal matrices satisfying $\mathrm{span}(\hat B_{1,\perp}) = \mathrm{span}(B_1)^\perp$ and $\mathrm{span}(\hat B_2) = \mathrm{span}(B_2)$.

Next, we define a key property required for our results.
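Definition 1 can be evaluated directly. The pure-Python sketch below (helper names are hypothetical) orthonormalizes both column spaces with Gram-Schmidt, forms $M = (I - \hat B_1\hat B_1^\top)\hat B_2$, which has the same spectral norm as $\hat B_{1,\perp}^\top\hat B_2$, and estimates $\|M\|_2$ by power iteration on the $k\times k$ matrix $M^\top M$. It assumes both inputs have full column rank; a production implementation would rather call an SVD routine. The usage lines also illustrate why this metric suits objective (8): replacing $B$ by $BQ$ for an invertible $Q$ leaves the distance at zero.

```python
import math

def orthonormal_basis(B):
    """Gram-Schmidt on the columns of B, given as a list of d rows of length k.

    Returns k orthonormal vectors of length d. Assumes B has full column rank.
    """
    basis = []
    for col in zip(*B):
        v = list(col)
        for q in basis:
            proj = sum(a * b for a, b in zip(v, q))
            v = [a - proj * b for a, b in zip(v, q)]
        norm = math.sqrt(sum(a * a for a in v))
        basis.append([a / norm for a in v])
    return basis

def principal_angle_dist(B1, B2, iters=500):
    """dist(B1, B2) = ||B1_perp^T B2_hat||_2 as in Definition 1.

    Uses ||B1_perp^T B2_hat||_2 = ||(I - Q1 Q1^T) Q2||_2, where Q1 and Q2 are
    orthonormal bases of span(B1) and span(B2).
    """
    Q1 = orthonormal_basis(B1)
    Q2 = orthonormal_basis(B2)
    # Columns of M = (I - Q1 Q1^T) Q2: strip from each q2 its component in span(B1).
    M = []
    for q2 in Q2:
        v = list(q2)
        for q1 in Q1:
            proj = sum(a * b for a, b in zip(q2, q1))
            v = [a - proj * b for a, b in zip(v, q1)]
        M.append(v)
    k = len(M)
    # ||M||_2^2 is the largest eigenvalue of the k x k PSD matrix G = M^T M.
    G = [[sum(a * b for a, b in zip(M[i], M[j])) for j in range(k)] for i in range(k)]
    v = [1.0 + 0.1 * i for i in range(k)]  # power-iteration start vector
    lam = 0.0
    for _ in range(iters):
        Gv = [sum(G[i][j] * v[j] for j in range(k)) for i in range(k)]
        norm = math.sqrt(sum(a * a for a in Gv))
        if norm == 0.0:
            return 0.0
        v = [a / norm for a in Gv]
        lam = sum(v[i] * sum(G[i][j] * v[j] for j in range(k)) for i in range(k))
    return math.sqrt(max(lam, 0.0))

if __name__ == "__main__":
    B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    BQ = [[2.0, 1.0], [0.0, 1.0], [2.0, 2.0]]  # B times Q = [[2, 1], [0, 1]]
    print(principal_angle_dist(B, BQ))  # essentially 0: same column space
    print(principal_angle_dist([[1.0], [0.0]], [[1.0], [1.0]]))  # about 0.7071 = sin(45 degrees)
```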
Definition 2.
A rank-$k$ matrix $M\in\mathbb{R}^{d\times d}$ is $\mu$-row-wise incoherent if $\max_{i\in[d]}\|m_i\|_2 \le \frac{\mu\sqrt{k}}{\sqrt{d}}\|M\|_F$, where $m_i\in\mathbb{R}^d$ is the $i$-th row of $M$.

Incoherence of the ground-truth matrices is a key property required for efficient matrix completion and other sensing problems with sparse measurements [Chi et al., 2019]. Since our measurement matrices are row-wise sparse, we require the row-wise incoherence of $W^*B^{*\top}$, which the following assumption implies with $\mu = 1$.

Assumption 2. (Client normalization) The ground-truth client-specific parameters satisfy $\|w_i^*\|_2 = \sqrt{k}$ for all $i\in[n]$, and $B^*$ has orthonormal columns.

Assumption 3. (Client diversity) Let $\bar\sigma_{\min,*}$ be the minimum singular value of any matrix $\frac{1}{\sqrt{rn}}W\in\mathbb{R}^{rn\times k}$ with rows being an $rn$-sized subset of ground-truth client-specific parameters $\{w_1^*, \ldots, w_n^*\}$. Then $\bar\sigma_{\min,*} > 0$.

Assumption 3 states that if we select any $rn$ clients, their optimal solutions span $\mathbb{R}^k$. Indeed, this assumption is weak, as we expect the number of participating clients $rn$ to be substantially larger than $k$. Note that if the client solutions do not span $\mathbb{R}^k$, recovering $B^*$ would be impossible because the samples $(x_i^j, y_i^j)$ may never contain any information about one or more features of $B^*$.

Our main result shows that the iterates $\{B^t\}_t$ generated by FedRep in this setting linearly converge to the optimal representation $B^*$ in principal angle distance.

Theorem 1.
Suppose Assumptions 2 and 3 hold. Let $\bar\sigma_{\min,*}$ be defined as in Assumption 3, and similarly define $\bar\sigma_{\max,*}$ as the maximum singular value. Also define $\kappa := \bar\sigma_{\max,*}/\bar\sigma_{\min,*}$ and $E := 1 - \mathrm{dist}^2(B^0, B^*)$. Suppose that $m \ge c\big(\kappa^4k^3\log(rn)/E^2 + \kappa^4k^2d/(E^2rn)\big)$ for some absolute constant $c$. Then for any $t$ and any $\eta \le 1/(4\bar\sigma_{\max,*}^2)$, we have
$$\mathrm{dist}(B^T, B^*) \le \big(1 - \eta E\bar\sigma_{\min,*}^2/2\big)^{T/2}\,\mathrm{dist}(B^0, B^*), \tag{10}$$
with probability at least $1 - Te^{-100\min(k^2\log(rn),\,d)}$.

From Assumption 3, we have $\bar\sigma_{\min,*}^2 > 0$, so the right-hand side of (10) strictly decreases with $T$ for an appropriate step size. Given the bound on $m$ and the fact that the algorithm converges exponentially fast, the total number of samples required per client to reach an $\epsilon$-accurate solution in principal angle distance is $\Theta\big(m\log(\kappa/(\epsilon E))\big)$, which is
$$\Theta\Big(\big[\kappa^4k^2\big(k\log(rn) + d/(rn)\big)\big]\log\big(\kappa/(\epsilon E)\big)\Big). \tag{11}$$
We next make a few remarks about this sample complexity.

When and whom does federation help?
Observe that for a single client with no collaboration, the sample complexity scales as $\Theta(d)$. With FedRep, however, the sample complexity scales as $\Theta(\log(rn) + d/(rn))$. Thus, so long as $\log(rn) + d/(rn) \ll d$, federation helps. This indeed holds in several settings, for instance when $1 \ll n \ll e^{\Theta(d)}$. In practical scenarios, $d$ (the data dimension) is large, and thus $e^{\Theta(d)}$ is exponentially larger; therefore collaboration helps each individual client. Furthermore, from the point of view of a new client who enters the system later, a representation is available for free, and this new client's sample complexity for adapting to its task is only $k\log(1/\epsilon)$. Thus, both the overall system benefits (a representation has been learned, which is useful for the new client because it now only needs to learn a head), and each individual client that did take part in the federated training also benefits.

Connection to matrix sensing. We would like to mention that the problem in (6) has a close connection with the matrix sensing problem, and in fact it can be considered as an instance of that problem; see the proof in Appendix B for more details. Considering this connection, our theoretical results also contribute to the theoretical study of matrix sensing. Although matrix sensing is a well-studied problem, our setting presents two new analytical challenges: (i) due to row-wise sparsity in the measurements, the sensing operator does not satisfy the commonly used Restricted Isometry Property within an efficient number of samples, i.e., it does not efficiently concentrate to an identity operation on all rank-$k$ matrices, and (ii) FedRep executes a novel non-symmetric procedure. We further discuss these challenges in Appendix B.1. To the best of our knowledge, Theorem 1 provides the first convergence result for an alternating minimization-descent procedure to solve a matrix sensing problem. It is also the first result to show sample-efficient linear convergence of any method for matrix sensing with rank-one, row-wise sparse measurements. The state-of-the-art result for the closest matrix sensing setting to ours is given by Zhong et al. [2015] for rank-1, independent Gaussian measurements, which our result matches up to an $O(\kappa)$ factor. However, our setting is more challenging as we have rank-1 and row-wise sparse measurements, and dependence on $\kappa$ has been previously observed in settings with sparse measurements, e.g., matrix completion [Jain et al., 2013].

New users and dimensionality reduction.
Theorem 1 is related to works studying representation learning in the context of multi-task learning. Tripuraneni et al. [2020] and Du et al. [2020] provided upper bounds on the generalization error resulting from learning a low-dimensional representation of tasks assumed to share a common representation. They show that, if the common representation is learned, then the excess risk bound on a new task is $O\big(\frac{C(\Phi)}{nm} + \frac{k}{m_{\text{new}}}\big)$, where $C(\Phi)$ is the complexity of the representation class $\Phi$ and $m_{\text{new}}$ is the number of labelled samples from the new task that the learner can use for fine-tuning. Since the number of labelled samples from the new task must exceed only $O(k)$, where $k$ is assumed to be small, these works demonstrate the dimensionality-reducing benefits of representation learning. Our work complements these results by showing how to provably and efficiently learn the representation in the linear case.

We focus on three points in our experiments: (i) the effect of many local updates for the local head in FedRep, (ii) the quality of the global representation learned by FedRep, and (iii) the applicability of FedRep to a wide range of datasets. Further experimental details are provided in the appendix.
We start by experimenting with an instance of the multi-linear regression problem analyzed in Section 4. Consistent with this formulation, we generate synthetic samples $x_i^j\sim\mathcal{N}(0, I_d)$ and labels $y_i^j\sim\mathcal{N}(w_i^{*\top}\hat B^{*\top}x_i^j, \sigma^2)$ with a small noise variance $\sigma^2$ (here we include additive Gaussian noise). The ground-truth heads $w_i^*\in\mathbb{R}^k$ for clients $i\in[n]$ and the ground-truth representation $\hat B^*\in\mathbb{R}^{d\times k}$ are generated randomly by sampling and normalizing Gaussian matrices.
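A small-scale sketch of this synthetic setup, running the linear FedRep procedure of Algorithm 2, follows. It is pure Python with illustrative defaults ($d = 5$, $k = 2$, $n = 20$, $r = 0.5$, $m = 30$, $\eta = 0.1$, $T = 200$) that are assumptions, not the settings behind Figure 3, and the function and parameter names (including `noise`, the label-noise standard deviation) are hypothetical. Each selected client draws fresh Gaussian samples, computes its head exactly from the $k\times k$ normal equations of (7), takes one gradient step on $B$, and the server averages. The initial $B^0$ is a random orthonormal matrix (an assumption), and progress is tracked with $\|(I - B^*B^{*\top})\hat B\|_F$, a Frobenius-norm upper bound on the principal angle distance that is zero exactly when the column spaces coincide.

```python
import math
import random

def gram_schmidt(cols):
    """Orthonormalize a list of column vectors (assumed linearly independent)."""
    basis = []
    for c in cols:
        v = list(c)
        for q in basis:
            proj = sum(a * b for a, b in zip(v, q))
            v = [a - proj * b for a, b in zip(v, q)]
        norm = math.sqrt(sum(a * a for a in v))
        basis.append([a / norm for a in v])
    return basis

def solve(A, b):
    """Solve the small square system A x = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, n):
            f = M[r][c] / M[c][c]
            for j in range(c, n + 1):
                M[r][j] -= f * M[c][j]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (M[r][n] - sum(M[r][j] * x[j] for j in range(r + 1, n))) / M[r][r]
    return x

def subspace_error(B, Bstar):
    """||(I - B* B*^T) Q||_F with Q an orthonormal basis of span(B); Bstar must be orthonormal."""
    total = 0.0
    for q in gram_schmidt(B):
        v = list(q)
        for s in Bstar:
            proj = sum(a * b for a, b in zip(q, s))
            v = [a - proj * b for a, b in zip(v, s)]
        total += sum(a * a for a in v)
    return math.sqrt(total)

def fedrep_linear(d=5, k=2, n=20, r=0.5, m=30, eta=0.1, T=200, noise=0.0, seed=0):
    """Sketch of Algorithm 2; matrices are stored as lists of k columns of length d."""
    rng = random.Random(seed)
    def gaussian_cols():
        return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)]
    Bstar = gram_schmidt(gaussian_cols())  # ground-truth representation
    heads = []  # ground-truth heads, rescaled to norm sqrt(k)
    for _ in range(n):
        w = [rng.gauss(0.0, 1.0) for _ in range(k)]
        scale = math.sqrt(k) / math.sqrt(sum(a * a for a in w))
        heads.append([a * scale for a in w])
    B = gram_schmidt(gaussian_cols())  # random orthonormal B^0
    errors = [subspace_error(B, Bstar)]
    for _ in range(T):
        local = []
        for i in rng.sample(range(n), int(r * n)):
            theta = [sum(heads[i][c] * Bstar[c][a] for c in range(k)) for a in range(d)]
            X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
            y = [sum(a * b for a, b in zip(x, theta)) + noise * rng.gauss(0.0, 1.0) for x in X]
            Z = [[sum(x[a] * B[c][a] for a in range(d)) for c in range(k)] for x in X]  # B^T x
            # Exact head for the batch loss (7): solve (Z^T Z) w = Z^T y.
            w = solve([[sum(z[p] * z[q] for z in Z) for q in range(k)] for p in range(k)],
                      [sum(z[p] * yj for z, yj in zip(Z, y)) for p in range(k)])
            # Gradient of (7) w.r.t. B is g w^T with g = (1/m) sum_j (w^T B^T x_j - y_j) x_j.
            res = [sum(wc * zc for wc, zc in zip(w, z)) - yj for z, yj in zip(Z, y)]
            g = [sum(e * x[a] for e, x in zip(res, X)) / m for a in range(d)]
            local.append([[B[c][a] - eta * g[a] * w[c] for a in range(d)] for c in range(k)])
        # Server: average the representations returned by the selected clients.
        B = [[sum(Bi[c][a] for Bi in local) / len(local) for a in range(d)] for c in range(k)]
        errors.append(subspace_error(B, Bstar))
    return errors

if __name__ == "__main__":
    errs = fedrep_linear()
    for t in (0, 25, 50, 100, 200):
        print(f"round {t:3d}: subspace error {errs[t]:.2e}")
```

Because the least-squares head satisfies the normal equations, $B^\top g = 0$ for the batch gradient vector $g$, so each local step only adds directions orthogonal to the current column space. With `noise=0.0` the labels follow the noiseless model of Section 4, and the representation update vanishes once $\mathrm{span}(B) = \mathrm{span}(B^*)$; with `noise > 0`, the error should instead level off at a noise-dependent floor.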
Figure 3 (plots of principal angle distance versus round $t$, $0 \le t \le 200$, for GD-GD, 10GD-GD, and FedRep; panels labelled $n = 10$, $n = 100$, and $n = 1000$): Comparison of (principal angle) distances between the ground-truth and estimated representations by FedRep and alternating gradient descent algorithms for different numbers of clients $n$. In all plots, $d = 10$, $k = 2$, $m = 5$, and $r = 0.\ldots$

Benefit of finding the optimal head.
We first demonstrate that the convergence of FedRep improves with a larger number of clients $n$, making it highly applicable to federated settings. Further, we give evidence showing that this improvement is augmented by the minimization step in FedRep, since methods that replace the minimization step in FedRep with 1 and 10 steps of gradient descent (GD-GD and 10GD-GD, respectively) do not scale properly with $n$. In Figure 3, we plot convergence trajectories for FedRep, GD-GD, and 10GD-GD for four different values of $n$ and fixed $m$, $d$, $k$, and $r$. As we observe in Figure 3, by increasing the number of nodes $n$, clients converge to the true representation faster. Also, running more local updates for finding the local head accelerates the convergence of FedRep. In particular, FedRep, which exactly finds the optimal local head at each round, has the fastest rate compared to GD-GD and 10GD-GD, which only run 1 and 10 local updates, respectively, to learn the head.

Generalization to new clients.
Next, we evaluate the effectiveness of the representation learned by FedRep in reducing the sample complexity for a new client that has not participated in training. We first train FedRep and FedAvg on a fixed set of $n = 100$ clients as in Figure 1. The new client has access to $m_{\text{new}}$ labelled local samples. It will use the representation $\hat B^*\in\mathbb{R}^{d\times k}$ learned by the training clients, where $(d, k) = (20, 2)$ […] $m_{\text{new}}$ labelled samples from the new client (Local Only) in Figure 4. The large error for FedAvg demonstrates that it does not learn the ground-truth representation. Meanwhile, the representation learned by FedRep allows an accurate model to be found for the new client as long as $m_{\text{new}} \ge k$, which drastically improves over the complexity for Local Only ($m_{\text{new}} = \Omega(d)$).

We next investigate whether these insights apply to nonlinear models and real datasets. Additional results and details of our experimental setup are provided in the appendix.
Datasets and Models.
We use four real datasets: CIFAR10 and CIFAR100 [Krizhevsky et al., 2009], FEMNIST [Caldas et al., 2018, Cohen et al., 2017] and Sent140 [Caldas et al., 2018]. The first three are image datasets and the last is a text dataset for which the goal is to classify the sentiment of a tweet as positive or negative. We control the heterogeneity of CIFAR10 and CIFAR100 by assigning different numbers S of classes per client, from among 10 and 100 classes, respectively. We use balanced versions of these datasets, meaning that each client has 50,000/n training samples. For FEMNIST, we first restrict the dataset to 10 handwritten letters and assign samples to clients according to a log-normal distribution as in [Li et al., 2019]. We consider a partition of n = 150 clients with an average of 148 samples/client. For Sent140, we use the natural assignment of tweets to their author, and use n = 183 clients with an average of 72 samples per client. We use a 5-layer CNN for the CIFAR datasets, a 2-hidden layer multilayer perceptron (MLP) for FEMNIST, and an RNN for Sent140 (details provided in the appendix).
[Figure 4 plot: "New Client MSE for d = 20, k = 2, m = 10, r = 0.1"; average MSE on the new client vs. the number of new-client samples; curves for Local Only, FedAvg, and FedRep.]
Figure 4: MSE on new clients sharing the representation after fine-tuning using various numbers of samples from the new client.
Baselines.
We compare against a variety of baselines: (1) GD-GD, which simultaneously updates a representation and unique local head in each local update as in [Arivazhagan et al., 2019]; (2) FedAvg [McMahan et al., 2017], which tries to learn a single global model; (3) Fed-MTL [Smith et al., 2017], which learns local models and a regularizer to encode relationships among the clients; (4) LG-FedAvg [Liang et al., 2020], which learns local representations and a global head; and (5) local only training with no global communication. Method (2) is the standard federated learning technique, (3) is a well-known method for achieving personalized models, and (1) and (4) are methods with intuitions similar to ours.
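To make the CIFAR heterogeneity control concrete, here is a minimal sketch of a label-skew split in which each client receives S single-class shards, following the shard construction that the appendix attributes to [McMahan et al., 2017]. It assumes a toy balanced label list and that every class size is a multiple of the shard size (so each shard holds one class); the function name `shard_partition` is hypothetical.

```python
import random
from collections import defaultdict

def shard_partition(labels, n_clients, shards_per_client, seed=0):
    """Label-skew split: sort sample indices by label, cut them into
    n_clients * shards_per_client equal-size shards, and give each client
    shards_per_client shards chosen uniformly at random (without replacement)."""
    num_shards = n_clients * shards_per_client
    shard_size = len(labels) // num_shards
    order = sorted(range(len(labels)), key=lambda i: labels[i])  # indices grouped by class
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    shard_ids = list(range(num_shards))
    random.Random(seed).shuffle(shard_ids)
    clients = defaultdict(list)
    for pos, sid in enumerate(shard_ids):
        clients[pos // shards_per_client].extend(shards[sid])
    return dict(clients)

# Toy stand-in for CIFAR10: 10 classes with 100 samples each (S = 2 classes/client, n = 10).
labels = [c for c in range(10) for _ in range(100)]
parts = shard_partition(labels, n_clients=10, shards_per_client=2)
for cid in range(3):
    print(cid, len(parts[cid]), sorted({labels[i] for i in parts[cid]}))
```

Because shards are drawn at random, a client can occasionally receive two shards of the same class, so in this sketch S is an upper bound on the number of classes per client.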
Implementation.
In each experiment we sample 10% of the clients on every round. We initialize all models randomly and train for T = 200 communication rounds. Unless otherwise noted, FedRep executes multiple local epochs of SGD with momentum to train the local head, followed by one epoch for the representation, in each local update. All other methods use E = 1 local epoch (for all parameters) except FedAvg, which uses E = 10.
Benefit of more local updates.
As mentioned in Section 1, a key advantage of our formulation is that it enables clients to run many local updates without causing divergence from the globally optimal solution. We demonstrate an example of this in Figure 5. Here, there are n = 100 clients, each of which has S = 2 classes of images. For FedAvg, we observe that running more local updates does not necessarily improve performance. In contrast, FedRep's performance is monotonically non-decreasing in E, i.e., FedRep requires less tuning of E and, in this experiment, is not hurt by additional local computation.

Table 1: Average local test accuracy over the final 10 rounds out of T = 200 total rounds, with 95% confidence intervals.
Dataset | FedRep | GD-GD | FedAvg | Fed-MTL | LG-FedAvg | Local
CIFAR10, S = 2, n = 100 | 86.… ± … | … | … | … | … | …
CIFAR10, S = 2, n = 1000 | 76.… ± … | … | … | … | … | …
CIFAR10, S = 5, n = 100 | 72.… ± … | … | … | … | … | …
CIFAR10, S = 5, n = 1000 | 53.… ± … | … | … | … | … | …
CIFAR100, S = 20, n = 100 | 38.… ± … | … | … | … | … | …
CIFAR100, S = 20, n = 500 | 21.… ± … | … | … | … | … | …
FEMNIST, n = 150 | 65.… ± … | … | … | … | … | …
Sent140, n = 183 | 72.… ± … | … | … | … | … | …

[Figure 5 plot: average local test accuracy on CIFAR10 with n = 100 and 2 classes/client; curves for FedRep and FedAvg with E = 1, 10, 20, and 40.]
Figure 5: CIFAR10 local test accuracy for different numbers of local epochs E for FedRep and FedAvg.
Robustness to varying numbers of samples/client and level of heterogeneity.
We show the average local test accuracy of all of the algorithms for a variety of settings in Table 1. In all cases, FedRep performs at least as well as the best method, up to statistical equivalence. Not surprisingly, the largest gain for FedRep over FedAvg comes in heterogeneous settings with many samples/client: CIFAR10 and CIFAR100 with 2 and 20 classes/client, respectively, and n = 100. In these settings, FedRep matches, or nearly matches, the performance of local-only training. Meanwhile, even in homogeneous and small-data settings conducive to global-only training, i.e., S = 5 and n = 1000 on CIFAR10, FedRep still outperforms all the baselines, even without any regularization added to prevent over-fitting to the head. This evidence supports our intuition that FedRep effectively interpolates between local and global learning.
Furthermore, FedRep outperforms all other baselines.
[Figure 6 plot: local test accuracy on FEMNIST-digits vs. number of epochs of fine-tuning the head, for models trained on FEMNIST-letters; curves for FedAvg, PerFedAvg, LG-FedAvg, and FedRep.]
Figure 6: Test accuracy on handwritten digits from FEMNIST after fine-tuning the head of models trained on FEMNIST-letters.
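The local-update schedule used in these experiments (many head updates per client, then one representation update, after which the server averages only the representation) can be sketched for the linear model of the synthetic experiments. This is a simplified illustration, not the authors' implementation: it assumes plain full-batch gradient steps instead of SGD with momentum, a squared loss, and hypothetical names (`local_update`, `fedrep_round`). The usage lines check one property any such scheme should have: the ground truth is a fixed point.

```python
import random

def predict(B, w, x):
    """Linear model prediction x^T B w (B: d rows of length k, w: length k)."""
    return sum(x[r] * sum(B[r][c] * w[c] for c in range(len(w))) for r in range(len(x)))

def local_update(B, w, X, y, head_steps, lr_head, lr_rep):
    """One client's local work on the loss (1/2m) sum_j (x_j^T B w - y_j)^2:
    head_steps gradient steps on the head w with B frozen, then a single
    gradient step on B with the updated head frozen."""
    m, d, k = len(y), len(B), len(w)
    w = list(w)
    for _ in range(head_steps):
        resid = [predict(B, w, x) - t for x, t in zip(X, y)]
        # Gradient w.r.t. w: (1/m) sum_j resid_j * B^T x_j
        g = [sum(res * sum(B[r][c] * x[r] for r in range(d)) for res, x in zip(resid, X)) / m
             for c in range(k)]
        w = [wc - lr_head * gc for wc, gc in zip(w, g)]
    resid = [predict(B, w, x) - t for x, t in zip(X, y)]
    # Gradient w.r.t. B: (1/m) sum_j resid_j * x_j w^T
    B_new = [[B[r][c] - lr_rep * w[c] * sum(res * x[r] for res, x in zip(resid, X)) / m
              for c in range(k)] for r in range(d)]
    return B_new, w

def fedrep_round(B, heads, data, sampled, head_steps=10, lr_head=0.1, lr_rep=0.1):
    """Sampled clients update locally; the server averages only the representations,
    while each client keeps its own head (heads is updated in place)."""
    local_Bs = []
    for i in sampled:
        X, y = data[i]
        B_i, heads[i] = local_update(B, heads[i], X, y, head_steps, lr_head, lr_rep)
        local_Bs.append(B_i)
    d, k = len(B), len(B[0])
    return [[sum(Bi[r][c] for Bi in local_Bs) / len(local_Bs) for c in range(k)]
            for r in range(d)]

random.seed(1)
d, k, n, m = 5, 2, 4, 6
B_true = [[random.gauss(0, d ** -0.5) for _ in range(k)] for _ in range(d)]
w_star = [[random.gauss(0, 1) for _ in range(k)] for _ in range(n)]
data = {}
for i in range(n):
    X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(m)]
    data[i] = (X, [predict(B_true, w_star[i], x) for x in X])

# Starting at the ground truth every gradient is zero, so one round leaves B unchanged.
heads = {i: list(w_star[i]) for i in range(n)}
B_next = fedrep_round(B_true, heads, data, sampled=[0, 2])
```

Keeping the head update separate from the representation step is what lets the head-step count grow without changing what the server averages; in this sketch, only `B` is ever communicated.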
Generalization to new clients.
We evaluate the strength of the representation learned by FedRep in terms of adaptation for new users. To do so, we consider an additional baseline, PerFedAvg [Fallah et al., 2020], which aims to find a single global model that can quickly adapt to a new task. First, we train FedRep, FedAvg, PerFedAvg, and LG-FedAvg in the usual setting on the partition of FEMNIST containing images of 10 handwritten letters (FEMNIST-letters). Then, we encounter clients with data from a different partition of the FEMNIST dataset, containing images of handwritten digits. We assume we have access to a dataset of 500 samples at this new client to fine-tune the head. Using these, with each of the algorithms, we fine-tune the head over multiple epochs while keeping the representation fixed. In Figure 6, we repeatedly sweep over the same 500 samples over multiple epochs to further refine the head, and plot the corresponding local test accuracy. As is apparent, FedRep has significantly better performance than these baselines.
We introduce a novel representation learning framework and algorithm for federated learning, and we provide both theoretical and empirical justification for its utility in federated settings. In particular, our proposed framework exploits the structure of federated learning by (i) leveraging all clients' data to learn a global representation that enhances each client's model and could potentially generalize to new users, and (ii) leveraging the computational power of clients to run multiple local updates for learning their local heads. Future work remains to analyze the convergence properties of FedRep in nonlinear settings.
Appendix
A Additional Experimental Results
A.1 Synthetic Data: Further comparison with GD-GD
[Figure 7 plot: function value $F(W^t, B^t)$ vs. t for the curves labeled AltMinGD and AltGD-GD, in a grid of panels with n ∈ {5, 10, 50, 100, 500} and m ∈ {5, 10}.]
Figure 7: Function values for FedRep and GD-GD. The value of m is fixed in each row and n is fixed in each column. Here r = 1 (full participation), and the average trajectories over 10 trials are plotted along with 95% confidence intervals. Principal angle distances are not plotted as the results are very similar. We see that the relative improvement of FedRep over GD-GD increases with n, highlighting the advantage of FedRep in settings with many clients.
Further experimental details.
In the synthetic data experiments, the ground-truth matrices $W^*$ and $B^*$ were generated by first sampling each element as an i.i.d. standard normal variable, then taking the QR factorization of the resulting matrix, and scaling it by $\sqrt{k}$ in the case of $W^*$. The clients each trained on the same m samples throughout the entire training process. Test samples were generated identically to the training samples but without noise. The iterates of both FedRep and GD-GD were initialized with the SVD of the result of 10 rounds of projected gradient descent on the unfactorized matrix sensing objective, as in Algorithm 1 in [Tu et al., 2016]. We note that FedRep exhibited the same convergence trajectories regardless of whether its iterates were initialized with random Gaussian samples or with the projected gradient descent procedure, whereas GD-GD was highly sensitive to its initialization, often not converging when initialized randomly.
A.2 Real Data: Comparison against additional baselines
Additional Results.
Extensions of Table 1 and Figure 6 to include additional baselines are given in Table 2 and Figure 8. We compare against two recently proposed methods for personalized federated learning, L2GD [Hanzely and Richtárik, 2020] and APFL [Deng et al., 2020], as well as FedProx [Li et al., 2018], a well-known method for federated learning in heterogeneous settings.
Table 2: Performance of additional baselines in the setting of Table 1.
Dataset | FedRep | L2GD | APFL | FedProx
CIFAR10, S = 2, n = 100 | 86.… ± … | … | … | …
CIFAR10, S = 2, n = 1000 | 76.… ± … | … | … | …
CIFAR10, S = 5, n = 100 | 72.… ± … | … | … | …
CIFAR10, S = 5, n = 1000 | 53.… ± … | … | … | …
CIFAR100, S = 20, n = 100 | 38.… ± … | … | … | …
CIFAR100, S = 20, n = 500 | 21.… ± … | … | … | …

[Figure 8 plot: local test accuracy on FEMNIST-digits vs. number of epochs of fine-tuning the head, for models trained on FEMNIST-letters; curves for FedAvg-10, PerFedAvg, LG-FedAvg, FedRep, L2GD, FedProx, FedAvg-5, and APFL.]
Figure 8: Test accuracy on handwritten digits from FEMNIST after fine-tuning the head of models trained on FEMNIST-letters.
We see from Table 2 that L2GD and APFL nearly match the performance of FedRep in settings with a small number of clients and a large amount of data per client (recall that for the CIFAR datasets, the number of samples per client is 50,000/n), but struggle in settings with many clients and little data per client. Meanwhile, in Figure 8, the L2GD global model is a relatively strong initial model for fine-tuning on new clients, but still does not match FedRep. Conversely, the APFL global model performs only as well as FedAvg with the same number (5) of local epochs (FedAvg-5), because the global model for APFL is essentially computed using FedAvg. Like FedAvg, FedProx performs best relative to FedRep in relatively homogeneous settings (S = 5 in Table 2) and learns an inferior representation on FEMNIST, as suggested by Figure 8.
Datasets.
The CIFAR10 and CIFAR100 datasets [Krizhevsky et al., 2009] were generated by randomly splitting the training data into Sn shards with 50,000/(Sn) images of a single class in each shard, as in [McMahan et al., 2017]. The full Federated EMNIST (FEMNIST) dataset contains 62 classes of handwritten characters (digits and letters), but in Tables 1 and 2 we use a subset with only 10 classes of handwritten letters. In particular, we followed the same dataset generation procedure as in [Li et al., 2019], but used 150 clients instead of 200. When testing on new clients as in Figures 6 and 8, we use samples from 10 classes of handwritten digits from FEMNIST, i.e., the MNIST dataset. In this phase there are 100 new clients, each with 500 samples from 5 different classes for fine-tuning. The fine-tuned models are then evaluated on 100 testing samples from these same 5 classes. For Sent140, we randomly sample 183 clients (Twitter users) that each have at least 50 samples (tweets). Each tweet has either positive or negative sentiment. Statistics of both the FEMNIST and Sent140 datasets we use are given in Table 3. For both FEMNIST and Sent140 we use the LEAF framework [Caldas et al., 2018].
Implementations.
All experiments were performed on a 3.7 GHz, 6-core Intel Core i7-8700K CPU, and the code was written in PyTorch. All methods use SGD with momentum parameter 0.5. For CIFAR10 and CIFAR100, the sample batch size is 50; for FEMNIST it is 10; and for Sent140 it is 4. The participation rate r is always 0.1, except in the fine-tuning phases in Figures 6 and 8, in which all clients are sampled in each round.
In Tables 1 and 2, we initialize all methods randomly and train for T = 200 communication rounds. The accuracy shown is the average local test accuracy over the final ten communication rounds. Intervals of 95% confidence are given over 3 random trials. Learning rates were adopted from published results [Liang et al., 2020, McMahan et al., 2017, Smith et al., 2017], or, in the case of Sent140, first tuned for FedAvg (resulting in a learning rate of 10^−[…]) and then applied to all other methods. For FedRep, the same learning rate was used to update both local and global parameters.
We used the implementations of FedAvg, Fed-MTL and LG-FedAvg found at https://github.com/pliang279/LG-FedAvg/ corresponding to the paper [Liang et al., 2020]. As in the experiments in [Liang et al., 2020], we used a 5-layer CNN for CIFAR10 and CIFAR100, with two convolutional layers followed by three fully-connected layers. For FEMNIST, we use an MLP with two hidden layers, and for Sent140 we use a pre-trained 300-dimensional GloVe embedding [Pennington et al., 2014] and train an RNN with an LSTM module followed by a fully-connected decoding layer, as in [Caldas et al., 2018].
For FedRep, we treated the head as the weights and biases in the final fully-connected layers in each of the models. For LG-FedAvg, we treated the first two convolutional layers and the first fully-connected layer of the model for CIFAR10 and CIFAR100 as the local representation, and the two final fully-connected layers as the global parameters. For FEMNIST, we set all parameters besides those in the final hidden layer and output layer to be the local representation parameters. For Sent140, we set the RNN module to be the local representation and the decoder to be the global parameters.
Unlike in the paper introducing LG-FedAvg [Liang et al., 2020], we did not initialize the models for all methods with the solution of many rounds of FedAvg (instead, we initialized randomly), and we computed the local test accuracy as the average local test accuracy over the final ten communication rounds, rather than the average of the maximum local test accuracy for each client over the entire training procedure. For PerFedAvg [Fallah et al., 2020] we used the first-order version (FO) and 10 local epochs for training.
For L2GD [Hanzely and Richtárik, 2020] we tuned the step size η over {…} (from which 0.[…]) […] p = 0.9, so the local parameters are trained in 10% of the communication rounds. We also set λ = 1 and E = 5. For APFL [Deng et al., 2020], we used the highest-performing constant value of α on CIFAR10 according to the experiments in [Deng et al., 2020], namely α = 0.75. Like L2GD, we used E = 5 local epochs, and the global models produced by both methods are used as the initial models for fine-tuning in Figure 8. For FedProx, we use E = 10 local epochs when n = 100 and E = 5 local epochs otherwise. We tune the parameter µ from among {…}, and use µ = 0.[…].

Table 3: Statistics of the FEMNIST and Sent140 datasets.
Dataset | Number of users (n) | Avg samples/user | Min samples/user
FEMNIST | 150 | 148 | 50
Sent140 | 183 | 72 | 50

B Proof of Main Theoretical Result
B.1 Preliminaries.
Definition 3.
For a random vector $x \in \mathbb{R}^d$ and a fixed matrix $A \in \mathbb{R}^{d\times d}$, the vector $A^\top x$ is called $\|A\|$-sub-gaussian if $y^\top A^\top x$ is sub-gaussian with sub-gaussian norm $\|A\|\|y\|$ for all $y \in \mathbb{R}^d$, i.e., $\mathbb{E}[\exp(y^\top A^\top x)] \le \exp\big(\|y\|_2^2\|A\|^2/2\big)$.
We use hats to denote orthonormal matrices (a matrix is called orthonormal if its set of columns is an orthonormal set). By Assumption 2, the ground-truth representation $B^*$ is orthonormal, so from now on we will write it as $\hat B^*$.
For a matrix $W \in \mathbb{R}^{n\times k}$ and a random set of indices $\mathcal{I} \subseteq [n]$ of cardinality $rn$, define $W_{\mathcal{I}} \in \mathbb{R}^{rn\times k}$ as the matrix formed by taking the rows of $W$ indexed by $\mathcal{I}$. Define $\bar\sigma_{\max,*} := \max_{\mathcal{I}\subseteq[n],\,|\mathcal{I}|=rn}\sigma_{\max}\big(\tfrac{1}{\sqrt{rn}}W^*_{\mathcal{I}}\big)$ and $\bar\sigma_{\min,*} := \min_{\mathcal{I}\subseteq[n],\,|\mathcal{I}|=rn}\sigma_{\min}\big(\tfrac{1}{\sqrt{rn}}W^*_{\mathcal{I}}\big)$, i.e., the maximum and minimum singular values of any matrix that can be obtained by taking $rn$ rows of $\tfrac{1}{\sqrt{rn}}W^*$. Note that by Assumption 2, each row of $W^*$ has norm $\sqrt{k}$, so $\tfrac{1}{\sqrt{rn}}$ acts as a normalizing factor such that $\|\tfrac{1}{\sqrt{rn}}W^*_{\mathcal{I}}\|_F = \sqrt{k}$. In addition, define $\kappa = \bar\sigma_{\max,*}/\bar\sigma_{\min,*}$.
Let $i$ now be an index over $[rn]$, and let $i'$ be an index over $[n]$. For random batches of samples $\{\{(x_i^j, y_i^j)\}_{j=1}^m\}_{i=1}^{rn}$, define the random linear operator $\mathcal{A}: \mathbb{R}^{rn\times d} \to \mathbb{R}^{rnm}$ as $\mathcal{A}(M) = [\langle A_{i,j}, M\rangle]_{1\le i\le rn,\,1\le j\le m} \in \mathbb{R}^{rnm}$. Here, $A_{i,j} := e_i(x_i^j)^\top$, where $e_i$ is the $i$-th standard basis vector in $\mathbb{R}^{rn}$, and $M \in \mathbb{R}^{rn\times d}$. Then, the loss function in (6) is equivalent to
$$\min_{B\in\mathbb{R}^{d\times k},\,W\in\mathbb{R}^{n\times k}}\Big\{F(B, W) := \frac{1}{2rnm}\,\mathbb{E}_{\mathcal{A},\mathcal{I}}\Big[\|Y - \mathcal{A}(W_{\mathcal{I}}B^\top)\|_2^2\Big]\Big\}, \qquad (12)$$
where $Y = \mathcal{A}(W^*_{\mathcal{I}}\hat B^{*\top}) \in \mathbb{R}^{rnm}$ is a concatenated vector of labels. It is now easily seen that the problem of recovering $W^*\hat B^{*\top}$ from finitely many measurements $\mathcal{A}(W^*_{\mathcal{I}}\hat B^{*\top})$ is an instance of matrix sensing.
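A small pure-Python sketch of the operator $\mathcal{A}$ and objective (12) may help; the function names are hypothetical and Gaussian samples are assumed. Each measurement touches only one row of $M$, so each client can compute its own measurements locally. The last lines evaluate the test matrix $M = e_1(x_1^1)^\top$ from the proof of Claim 1, whose normalized measurement energy is at least $\|x_1^1\|_2^2/m$, far above $1+\delta$ when $d \gg m$.

```python
import random

def apply_A(M, X):
    """Row-wise sparse measurements: A(M)_{i,j} = <e_i (x_i^j)^T, M> = (x_i^j)^T M[i].
    M is rn x d (list of rows); X[i][j] is client i's j-th sample in R^d."""
    return [sum(a * b for a, b in zip(x, M[i])) for i in range(len(M)) for x in X[i]]

def outer_WBt(W, B):
    """Return W B^T for W (rn rows of length k) and B (d rows of length k)."""
    return [[sum(wc * bc for wc, bc in zip(w_row, b_row)) for b_row in B] for w_row in W]

def objective(W, B, X, Y):
    """Empirical objective (1/(2 rn m)) ||Y - A(W B^T)||_2^2 for one batch."""
    pred = apply_A(outer_WBt(W, B), X)
    return sum((p - t) ** 2 for p, t in zip(pred, Y)) / (2 * len(Y))

random.seed(0)
rn, m, d, k = 3, 4, 50, 2
X = [[[random.gauss(0, 1) for _ in range(d)] for _ in range(m)] for _ in range(rn)]
W_star = [[random.gauss(0, 1) for _ in range(k)] for _ in range(rn)]
B_star = [[random.gauss(0, 1) for _ in range(k)] for _ in range(d)]
Y = apply_A(outer_WBt(W_star, B_star), X)   # noiseless labels, length rn * m
print(objective(W_star, B_star, X, Y))      # prints 0.0: the ground truth fits exactly

# Test matrix from the proof of Claim 1: M = e_1 (x_1^1)^T, i.e. only row 1 is non-zero.
x11 = X[0][0]
M = [list(x11)] + [[0.0] * d for _ in range(rn - 1)]
sq_norm = sum(v * v for v in x11)
ratio = sum(v * v for v in apply_A(M, X)) / m / sq_norm   # ||A(M)/sqrt(m)||^2 / ||M||_F^2
print(ratio, sq_norm / m)   # ratio is at least ||x_1^1||^2 / m, roughly d / m here
```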
Moreover, the updates of FedRep satisfy the following recursion:
$$W^{t+1}_{\mathcal{I}^t} = \arg\min_{W_{\mathcal{I}^t}\in\mathbb{R}^{rn\times k}}\frac{1}{2rnm}\big\|\mathcal{A}^t\big(W^*_{\mathcal{I}^t}\hat B^{*\top} - W_{\mathcal{I}^t}B^{t\top}\big)\big\|_2^2 \qquad (13)$$
$$B^{t+1} = B^t - \frac{\eta}{rnm}\Big((\mathcal{A}^t)^\dagger\mathcal{A}^t\big(W^{t+1}_{\mathcal{I}^t}B^{t\top} - W^*_{\mathcal{I}^t}\hat B^{*\top}\big)\Big)^\top W^{t+1}_{\mathcal{I}^t} \qquad (14)$$
where $\mathcal{A}^t$ is an instance of $\mathcal{A}$, and $(\mathcal{A}^t)^\dagger$ is the adjoint operator of $\mathcal{A}^t$, i.e., $(\mathcal{A}^t)^\dagger\mathcal{A}^t(M) = \sum_{i=1}^{rn}\sum_{j=1}^m\langle A^t_{i,j}, M\rangle A^t_{i,j}$. Note that for the purposes of analysis, it does not matter how $w^{t+1}_{i'}$ is computed for $i' \notin \mathcal{I}^t$, as these vectors do not affect the computation of $B^{t+1}$. Moreover, our analysis does not rely on any particular properties of the batches $\mathcal{I}^1, \ldots, \mathcal{I}^T$ other than the fact that they have cardinality $rn$, so without loss of generality we assume $\mathcal{I}^t = [rn]$ for all $t = 1, \ldots, T$ and drop the subscripts $\mathcal{I}^t$ on $W^t$.
We next discuss two analytical challenges involved in showing Theorem 1.
(i) Row-wise sparse measurements. Recall that the measurement matrices $A^t_{i,j}$ have non-zero elements only in the $i$-th row. This property is beneficial in the sense that it allows for distributing the sensing computation across the $n$ clients. However, it also means that the operators $\{\frac{1}{\sqrt{m}}\mathcal{A}^t\}_t$ do not satisfy the Restricted Isometry Property (RIP), which prevents us from using standard RIP-based analysis. For background, the RIP is defined as follows:
Definition 4 (RIP). An operator $\mathcal{B}: \mathbb{R}^{n\times d} \to \mathbb{R}^{nm}$ satisfies the $k$-RIP with parameter $\delta_k \in [0, 1)$ if and only if
$$(1 - \delta_k)\|M\|_F^2 \le \|\mathcal{B}(M)\|_2^2 \le (1 + \delta_k)\|M\|_F^2 \qquad (15)$$
holds simultaneously for all $M \in \mathbb{R}^{n\times d}$ of rank at most $k$.
Claim 1.
Let $\mathcal{A}: \mathbb{R}^{rn\times d} \to \mathbb{R}^{rnm}$ be such that $\mathcal{A}(M) = [\langle e_i(x_i^j)^\top, M\rangle]_{1\le i\le rn,\,1\le j\le m}$, and let the samples $x_i^j$ be i.i.d. sub-gaussian random vectors with mean $0_d$ and covariance $I_d$. Then if $m \le d/4$, with probability at least $1 - e^{-cd}$ for some absolute constant $c$, $\frac{1}{\sqrt{m}}\mathcal{A}$ does not satisfy 1-RIP for any constant $\delta \in [0, 1)$.
Proof. Let $M = e_1(x_1^1)^\top$. Then
$$\Big\|\frac{1}{\sqrt{m}}\mathcal{A}(M)\Big\|_2^2 = \frac{1}{m}\sum_{i=1}^{rn}\sum_{j=1}^m\big\langle e_i(x_i^j)^\top, e_1(x_1^1)^\top\big\rangle^2 = \frac{1}{m}\|x_1^1\|_2^4 + \frac{1}{m}\sum_{j=2}^m\langle x_1^j, x_1^1\rangle^2 \ge \frac{1}{m}\|x_1^1\|_2^4 \qquad (16)$$
Also observe that $\|M\|_F^2 = \|x_1^1\|_2^2$. Therefore, we have
$$\mathbb{P}\left(\frac{\|\frac{1}{\sqrt{m}}\mathcal{A}(M)\|_2^2}{\|M\|_F^2} \ge \frac{d}{2m}\right) \ge \mathbb{P}\left(\frac{\frac{1}{m}\|x_1^1\|_2^4}{\|x_1^1\|_2^2} \ge \frac{d}{2m}\right) = \mathbb{P}\Big(\|x_1^1\|_2^2 \ge \frac{d}{2}\Big) \ge 1 - \mathbb{P}\Big(\|x_1^1\|_2^2 - d \le -\frac{d}{2}\Big) \ge 1 - e^{-cd} \qquad (17)$$
where the last inequality follows for some absolute constant $c$ by the sub-exponential property of $\|x_1^1\|_2^2$ and the fact that $\mathbb{E}[\|x_1^1\|_2^2] = d$. Thus, with probability at least $1 - e^{-cd}$, $\|\frac{1}{\sqrt{m}}\mathcal{A}(M)\|_2^2 \ge \frac{d}{2m}\|M\|_F^2$, meaning that $\frac{1}{\sqrt{m}}\mathcal{A}$ does not satisfy 1-RIP with high probability if $m \le d/4$ (in which case $\frac{d}{2m} \ge 2 > 1 + \delta$).
Claim 1 shows that we cannot use the RIP to show $O(d/(rn))$ sample complexity for $m$; instead, this approach would require $m = \Omega(d)$. Fortunately, we do not need concentration of the measurements for all rank-$k$ matrices $M$, but only for a particular class of rank-$k$ matrices that are row-wise incoherent. Leveraging the row-wise incoherence of the matrices being measured allows us to show that we only require $m = \Omega(k\log(rn) + kd/(rn))$ samples per user (ignoring dimension-independent constants).
(ii) Non-symmetric updates.
Existing analyses for nonconvex matrix sensing study algorithms with symmetric update schemes for the factors $W$ and $B$: either alternating minimization, e.g., [Jain et al., 2013], or alternating gradient descent, e.g., [Tu et al., 2016]. FedRep instead updates $W$ by exact minimization and $B$ by a gradient step. Here we show contraction due to the gradient descent step in principal angle distance, differing from the standard result for gradient descent using Procrustes distance [Park et al., 2018, Tu et al., 2016, Zheng and Lafferty, 2016]. We combine aspects of both types of analysis in our proof.
B.2 Auxiliary Lemmas
We start by showing that we can assume without loss of generality that $B^t$ is orthonormalized at the end of every communication round.
Lemma 1.
Let $W_t \in \mathbb{R}^{rn\times k}$ and $B_t \in \mathbb{R}^{d\times k}$ denote the iterates of Algorithm 2 as outlined in (13) and (14) (with the subscript $\mathcal{I}^t$ dropped). Now consider the modified algorithm given by the following recursion:
$$\widetilde W^{t+1} = \arg\min_W\big\|\mathcal{A}^t\big(W(B^t)^\top - W^*(\hat B^*)^\top\big)\big\|_2 \qquad (18)$$
$$\widetilde B^{t+1} = B^t - \frac{\eta}{rnm}\Big((\mathcal{A}^t)^\dagger\mathcal{A}^t\big(\widetilde W^{t+1}(B^t)^\top - W^*(\hat B^*)^\top\big)\Big)^\top\widetilde W^{t+1} \qquad (19)$$
$$B^{t+1} = \widetilde B^{t+1}(\widetilde R^{t+1})^{-1} \qquad (20)$$
where $B^{t+1}\widetilde R^{t+1}$ is the QR factorization of $\widetilde B^{t+1}$. Then the column spaces of $B_t$ and $\widetilde B^t$ are equivalent for all $t$.
Proof. The proof follows a similar argument as Lemma 4.4 in Jain et al. [2013]. Assume that the claim holds for iteration $t$. Then there is some full-rank $R_B \in \mathbb{R}^{k\times k}$ such that $\widetilde B^t R_B = B_t$. Then $B^t\widetilde R^t R_B = B_t$, where $\widetilde R^t R_B$ is full rank. Since
$$\widetilde W^{t+1} = \arg\min_W\big\|\mathcal{A}^t\big(W(B^t)^\top - W^*(\hat B^*)^\top\big)\big\|_2 = \arg\min_W\big\|\mathcal{A}^t\big((W(\widetilde R^tR_B)^{-\top})B_t^\top - W^*(\hat B^*)^\top\big)\big\|_2, \qquad (21)$$
we have that $\widetilde W^{t+1}(\widetilde R^tR_B)^{-\top}$ minimizes $\|\mathcal{A}^t(WB_t^\top - W^*(\hat B^*)^\top)\|_2$ over $W$, since $(\widetilde R^tR_B)^\top$ is full rank.
So $W_{t+1} = \widetilde W^{t+1}(\widetilde R^tR_B)^{-\top}$, and the column spaces of $\widetilde W^{t+1}$ and $W_{t+1}$ are equivalent. Next, recall the definition of $B_{t+1}$:
$$\begin{aligned} B_{t+1} &= B_t - \frac{\eta}{rnm}\Big((\mathcal{A}^t)^\dagger\mathcal{A}^t\big(W_{t+1}B_t^\top - W^*(\hat B^*)^\top\big)\Big)^\top W_{t+1} && (22)\\ &= B^t\widetilde R^tR_B - \frac{\eta}{rnm}\Big((\mathcal{A}^t)^\dagger\mathcal{A}^t\big(\widetilde W^{t+1}(\widetilde R^tR_B)^{-\top}(\widetilde R^tR_B)^\top(B^t)^\top - W^*(\hat B^*)^\top\big)\Big)^\top\widetilde W^{t+1}(\widetilde R^tR_B)^{-\top}\\ &= \Big[B^t,\ -\frac{\eta}{rnm}\Big((\mathcal{A}^t)^\dagger\mathcal{A}^t\big(\widetilde W^{t+1}(B^t)^\top - W^*(\hat B^*)^\top\big)\Big)^\top\widetilde W^{t+1}\Big]\begin{bmatrix}\widetilde R^tR_B\\ (\widetilde R^tR_B)^{-\top}\end{bmatrix} && (23)\end{aligned}$$
so the column space of $B_{t+1}$ is equal to the column space of $\Big[B^t,\ -\frac{\eta}{rnm}\big((\mathcal{A}^t)^\dagger\mathcal{A}^t(\widetilde W^{t+1}(B^t)^\top - W^*(\hat B^*)^\top)\big)^\top\widetilde W^{t+1}\Big]$. Finally, note that $\widetilde B^{t+1}$ can be written as:
$$\widetilde B^{t+1} = \Big[B^t,\ -\frac{\eta}{rnm}\Big((\mathcal{A}^t)^\dagger\mathcal{A}^t\big(\widetilde W^{t+1}(B^t)^\top - W^*(\hat B^*)^\top\big)\Big)^\top\widetilde W^{t+1}\Big]\begin{bmatrix}I_k\\ I_k\end{bmatrix} \qquad (24)$$
so $\widetilde B^{t+1}$ has column space that is also equal to the column space of this same block matrix $\Big[B^t,\ -\frac{\eta}{rnm}\big((\mathcal{A}^t)^\dagger\mathcal{A}^t(\widetilde W^{t+1}(B^t)^\top - W^*(\hat B^*)^\top)\big)^\top\widetilde W^{t+1}\Big]$.
Note that we cannot orthonormalize $W^t$, either in practice (due to privacy constraints) or for analysis only.
In light of Lemma 1, we now analyze the modified algorithm in Lemma 1, in which $B^t$ is orthonormalized after each iteration. We will use our standard notation $W^t, B^t$ to denote the iterates of this algorithm, with $\hat B^t$ being the orthonormalized version of $B^t$.
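Lemma 1 relies on the fact that orthonormalizing $B^t$ changes only its basis, not its column space, and the principal angle distance depends only on the column space. A minimal numerical check follows, assuming Gaussian test matrices, using modified Gram-Schmidt in place of a library QR routine, and computing $\mathrm{dist}(\hat B, \hat B^*) = \|(I_d - \hat B\hat B^\top)\hat B^*\|_2$, the quantity that appears (transposed) in the proof of Lemma 4; the helper names are hypothetical.

```python
import random

def orthonormalize(B):
    """Thin QR via modified Gram-Schmidt: return Q (d x k) with orthonormal columns and col(Q) = col(B)."""
    d, k = len(B), len(B[0])
    q = []
    for c in range(k):
        v = [B[r][c] for r in range(d)]
        for u in q:
            proj = sum(a * b for a, b in zip(u, v))
            v = [a - proj * b for a, b in zip(v, u)]
        norm = sum(a * a for a in v) ** 0.5
        q.append([a / norm for a in v])
    return [[q[c][r] for c in range(k)] for r in range(d)]

def principal_angle_dist(Bh, Bs):
    """dist(Bh, Bs) = ||(I_d - Bh Bh^T) Bs||_2 for orthonormal Bh, Bs (d x k),
    computed by power iteration on the k x k matrix E^T E."""
    d, k = len(Bh), len(Bh[0])
    BtBs = [[sum(Bh[s][a] * Bs[s][c] for s in range(d)) for c in range(k)] for a in range(k)]
    E = [[Bs[r][c] - sum(Bh[r][a] * BtBs[a][c] for a in range(k)) for c in range(k)]
         for r in range(d)]
    G = [[sum(E[r][a] * E[r][b] for r in range(d)) for b in range(k)] for a in range(k)]
    v, lam = [1.0] * k, 0.0
    for _ in range(200):
        u = [sum(G[a][b] * v[b] for b in range(k)) for a in range(k)]
        lam = sum(x * x for x in u) ** 0.5
        if lam == 0.0:
            return 0.0
        v = [x / lam for x in u]
    return lam ** 0.5

random.seed(0)
d, k = 6, 2
B = [[random.gauss(0, 1) for _ in range(k)] for _ in range(d)]
R = [[2.0, 1.0], [0.5, -1.0]]   # an arbitrary invertible k x k matrix (det = -2.5)
BR = [[sum(B[r][a] * R[a][c] for a in range(k)) for c in range(k)] for r in range(d)]
Q1, Q2 = orthonormalize(B), orthonormalize(BR)
print(principal_angle_dist(Q1, Q2))   # near zero: same column space, different basis
```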
For clarity we restate this modified algorithm with the standard notation here:
$$W^{t+1} = \arg\min_W\frac{1}{2rnm}\big\|\mathcal{A}^t\big(W(\hat B^t)^\top - W^*(\hat B^*)^\top\big)\big\|_2^2 \qquad (25)$$
$$B^{t+1} = \hat B^t - \frac{\eta}{rnm}\Big((\mathcal{A}^t)^\dagger\mathcal{A}^t\big(W^{t+1}(\hat B^t)^\top - W^*(\hat B^*)^\top\big)\Big)^\top W^{t+1} \qquad (26)$$
$$\hat B^{t+1} = B^{t+1}(R^{t+1})^{-1} \qquad (27)$$
where $\hat B^{t+1}R^{t+1}$ is the QR factorization of $B^{t+1}$. We next explicitly compute $W^{t+1}$. Since the rest of the proof analyzes a particular communication round $t$, we drop superscripts $t$ on the measurement operators $\mathcal{A}^t$ and matrices $A^t_{i,j}$ for ease of notation.
Lemma 2. In the modified algorithm, where $B$ is orthonormalized after each update, the update for $W$ is:
$$W^{t+1} = W^*\hat B^{*\top}\hat B^t - F \qquad (28)$$
where $F$ is defined in equation (33) below.
Proof. We adapt the argument from Lemma 4.5 in [Jain et al., 2013] to compute the update for $W^{t+1}$, and borrow heavily from their notation.
Let $w^{t+1}_p$ (respectively $\hat b^t_p$) be the $p$-th column of $W^{t+1}$ (respectively $\hat B^t$). Since $W^{t+1}$ minimizes $\tilde F(W, \hat B^t) := \frac{1}{2rnm}\|\mathcal{A}(W^*(\hat B^*)^\top - W(\hat B^t)^\top)\|_2^2$ with respect to $W$, we have $\nabla_{w_p}\tilde F(W^{t+1}, \hat B^t) = 0$ for all $p \in [k]$. Thus, for any $p \in [k]$, we have
$$\begin{aligned} 0 &= \nabla_{w_p}\tilde F(W^{t+1}, \hat B^t)\\ &= \frac{1}{rnm}\sum_{i=1}^{rn}\sum_{j=1}^m\big\langle A_{i,j},\, W^{t+1}(\hat B^t)^\top - W^*(\hat B^*)^\top\big\rangle A_{i,j}\hat b^t_p\\ &= \frac{1}{rnm}\sum_{i=1}^{rn}\sum_{j=1}^m\Big(\sum_{q=1}^k(\hat b^t_q)^\top A_{i,j}^\top w^{t+1}_q - \sum_{q=1}^k(\hat b^*_q)^\top A_{i,j}^\top w^*_q\Big)A_{i,j}\hat b^t_p \end{aligned}$$
This implies
$$\frac{1}{m}\sum_{q=1}^k\sum_{i=1}^{rn}\sum_{j=1}^m A_{i,j}\hat b^t_p(\hat b^t_q)^\top A_{i,j}^\top w^{t+1}_q = \frac{1}{m}\sum_{q=1}^k\sum_{i=1}^{rn}\sum_{j=1}^m A_{i,j}\hat b^t_p(\hat b^*_q)^\top A_{i,j}^\top w^*_q \qquad (29)$$
To solve for $w^{t+1}$, we define $G$, $C$, and $D$ as $rnk$-by-$rnk$ block matrices, as follows:
$$G := \begin{bmatrix}G_{11} & \cdots & G_{1k}\\ \vdots & \ddots & \vdots\\ G_{k1} & \cdots & G_{kk}\end{bmatrix},\quad C := \begin{bmatrix}C_{11} & \cdots & C_{1k}\\ \vdots & \ddots & \vdots\\ C_{k1} & \cdots & C_{kk}\end{bmatrix},\quad D := \begin{bmatrix}D_{11} & \cdots & D_{1k}\\ \vdots & \ddots & \vdots\\ D_{k1} & \cdots & D_{kk}\end{bmatrix} \qquad (30)$$
where, for $p, q \in [k]$: $G_{pq} := \frac{1}{m}\sum_{i=1}^{rn}\sum_{j=1}^m A_{i,j}\hat b^t_p(\hat b^t_q)^\top A_{i,j}^\top \in \mathbb{R}^{rn\times rn}$, $C_{pq} := \frac{1}{m}\sum_{i=1}^{rn}\sum_{j=1}^m A_{i,j}\hat b^t_p(\hat b^*_q)^\top A_{i,j}^\top \in \mathbb{R}^{rn\times rn}$, and $D_{pq} := \langle\hat b^t_p, \hat b^*_q\rangle I_{rn} \in \mathbb{R}^{rn\times rn}$. Recall that $\hat b^t_p$ is the $p$-th column of $\hat B^t$ and $\hat b^*_q$ is the $q$-th column of $\hat B^*$. Further, define
$$\tilde w^{t+1} = \begin{bmatrix}w^{t+1}_1\\ \vdots\\ w^{t+1}_k\end{bmatrix} \in \mathbb{R}^{rnk},\qquad \tilde w^* = \begin{bmatrix}w^*_1\\ \vdots\\ w^*_k\end{bmatrix} \in \mathbb{R}^{rnk}.$$
Then, by (29), we have
$$\tilde w^{t+1} = G^{-1}C\tilde w^* = D\tilde w^* - G^{-1}(GD - C)\tilde w^*,$$
where $G$ is invertible conditioned on the event that its minimum singular value is strictly positive, which Lemma 3 shows holds with high probability. Now consider the $p$-th block of $\tilde w^{t+1}$, and let $(G^{-1}(GD - C)\tilde w^*)_p$ denote the $p$-th block of $G^{-1}(GD - C)\tilde w^*$. We have
$$\tilde w^{t+1}_p = \sum_{q=1}^k\langle\hat b^t_p, \hat b^*_q\rangle w^*_q - (G^{-1}(GD - C)\tilde w^*)_p = \sum_{q=1}^k w^*_q(\hat b^*_q)^\top\hat b^t_p - (G^{-1}(GD - C)\tilde w^*)_p = W^*(\hat B^*)^\top\hat b^t_p - (G^{-1}(GD - C)\tilde w^*)_p \qquad (31)$$
By constructing $W^{t+1}$ such that the $p$-th column of $W^{t+1}$ is $w^{t+1}_p$ for all $p \in [k]$, we obtain
$$W^{t+1} = W^*(\hat B^*)^\top\hat B^t - F \qquad (32)$$
where
$$F = \big[(G^{-1}(GD - C)\tilde w^*)_1, \ldots, (G^{-1}(GD - C)\tilde w^*)_k\big] \qquad (33)$$
and $(G^{-1}(GD - C)\tilde w^*)_p$ is the $p$-th $rn$-dimensional block of the $rnk$-dimensional vector $G^{-1}(GD - C)\tilde w^*$.
Next we bound the Frobenius norm of the matrix $F$, which requires multiple steps. First, we establish some helpful notation. We drop superscripts indicating the iteration number $t$ for simplicity.
Again let $w^*$ be the $rnk$-dimensional vector formed by stacking the columns of $W^*$, and let $\hat b_p$ (respectively $\hat b^*_q$) be the $p$-th column of $\hat B$ (respectively the $q$-th column of $\hat B^*$). Recall that $F$ can be obtained by stacking $G^{-1}(GD - C)w^*$ into $k$ columns of length $rn$, i.e.,
$\mathrm{vec}(F) = G^{-1}(GD - C)w^*$. Further, $G \in \mathbb{R}^{rnk\times rnk}$ is a block matrix whose blocks $G_{pq} \in \mathbb{R}^{rn\times rn}$ for $p, q \in [k]$ are given by:
$$G_{pq} = \frac{1}{m}\sum_{i=1}^{rn}\sum_{j=1}^m A_{i,j}\hat b_p\hat b_q^\top A_{i,j}^\top = \frac{1}{m}\sum_{i=1}^{rn}\sum_{j=1}^m e_i(x_i^j)^\top\hat b_p\hat b_q^\top x_i^j e_i^\top \qquad (34)$$
So, each $G_{pq}$ is diagonal with diagonal entries
$$(G_{pq})_{ii} = \frac{1}{m}\sum_{j=1}^m(x_i^j)^\top\hat b_p\hat b_q^\top x_i^j = \hat b_p^\top\Big(\frac{1}{m}\sum_{j=1}^m x_i^j(x_i^j)^\top\Big)\hat b_q \qquad (35)$$
Define $\Pi_i := \frac{1}{m}\sum_{j=1}^m x_i^j(x_i^j)^\top$ for all $i \in [rn]$. Similarly as above, each block $C_{pq}$ of $C$ is diagonal with entries
$$(C_{pq})_{ii} = \hat b_p^\top\Pi_i\hat b^*_q \qquad (36)$$
Analogously to the matrix completion analysis in [Jain et al., 2013], we define the following matrices, for all $i \in [rn]$:
$$G_i := \big[\hat b_p^\top\Pi_i\hat b_q\big]_{1\le p,q\le k} = \hat B^\top\Pi_i\hat B,\qquad C_i := \big[\hat b_p^\top\Pi_i\hat b^*_q\big]_{1\le p,q\le k} = \hat B^\top\Pi_i\hat B^* \qquad (37)$$
In words, $G_i$ is the $k\times k$ matrix formed by taking the $i$-th diagonal entry of each block $G_{pq}$, and likewise for $C_i$. Recall that $D$ also has diagonal blocks, in particular $D_{pq} = \langle\hat b_p, \hat b^*_q\rangle I_{rn}$; thus we also define $D_i := [\langle\hat b_p, \hat b^*_q\rangle]_{1\le p,q\le k} = \hat B^\top\hat B^*$.
Using this notation we can decouple $G^{-1}(GD - C)w^*$ into $rn$ subvectors. Namely, let $w^*_i \in \mathbb{R}^k$ be the vector formed by taking the $(p\,rn + i)$-th elements of $w^*$ for $p = 0, \ldots, k-1$ (i.e., the $i$-th row of $W^*$), and similarly, let $f_i$ be the vector formed by taking the $(p\,rn + i)$-th elements of $G^{-1}(GD - C)w^*$ for $p = 0, \ldots, k-1$. Then
$$f_i = (G_i)^{-1}(G_iD_i - C_i)w^*_i \qquad (38)$$
is the $i$-th row of $F$. Now we control $\|F\|_F$.
Lemma 3.
Let $\delta_k = c\,k^{3/2}\sqrt{\log(rn)}/\sqrt{m}$ for some absolute constant $c$; then $\|G^{-1}\|_2 \le \frac{1}{1-\delta_k}$ with probability at least $1 - e^{-\Omega(k^3\log(rn))}$.
Proof. We must lower bound $\sigma_{\min}(G)$. For a vector $z \in \mathbb{R}^{rnk}$, let $z_i \in \mathbb{R}^k$ denote the vector formed by taking the $(p\,rn + i)$-th elements of $z$ for $p = 0, \ldots, k-1$. Since $G$ is symmetric positive semidefinite, we have
$$\sigma_{\min}(G) = \min_{z:\|z\|_2=1} z^\top Gz = \min_{z:\|z\|_2=1}\sum_{i=1}^{rn}(z_i)^\top G_iz_i = \min_{z:\|z\|_2=1}\sum_{i=1}^{rn}(z_i)^\top\hat B^\top\Pi_i\hat Bz_i \ge \min_{i\in[rn]}\sigma_{\min}(\hat B^\top\Pi_i\hat B),$$
where the last step uses $\sum_{i=1}^{rn}\|z_i\|_2^2 = \|z\|_2^2 = 1$. Note that the matrix $\hat B^\top\Pi_i\hat B$ can be written as follows:
$$\hat B^\top\Pi_i\hat B = \sum_{j=1}^m\Big(\frac{1}{\sqrt{m}}\hat B^\top x_i^j\Big)\Big(\frac{1}{\sqrt{m}}\hat B^\top x_i^j\Big)^\top \qquad (39)$$
Let $v_i^j := \frac{1}{\sqrt{m}}\hat B^\top x_i^j$ for all $i \in [rn]$ and $j \in [m]$, and note that each $v_i^j$ is i.i.d. $\frac{1}{\sqrt{m}}\hat B$-sub-gaussian. Thus, using the one-sided version of equation (4.22) (Theorem 4.6.1) in [Vershynin, 2018], we have
$$\sigma_{\min}(\hat B^\top\Pi_i\hat B) \ge 1 - C\Big(\sqrt{\frac{k}{m}} + \frac{r}{\sqrt{m}}\Big) \qquad (40)$$
with probability at least $1 - e^{-r^2}$ for $m \ge k$ and some absolute constant $C$. Choosing $r$ such that $\delta_k = C\big(\sqrt{k/m} + r/\sqrt{m}\big)$ yields
$$\sigma_{\min}(\hat B^\top\Pi_i\hat B) \ge 1 - \delta_k \qquad (41)$$
with probability at least $1 - e^{-(\delta_k\sqrt{m}/C - \sqrt{k})^2}$ for $m > k$. Now, letting $\delta_k = Ck^{3/2}\sqrt{\log(rn)}/\sqrt{m}$, we have that (41) holds with probability at least
$$1 - \exp\Big(-\big(k^{3/2}\sqrt{\log(rn)} - \sqrt{k}\big)^2\Big) \ge 1 - e^{-\Omega(k^3\log(rn))} \qquad (42)$$
Finally, taking a union bound over $i \in [rn]$ yields $\sigma_{\min}(G) \ge 1 - \delta_k$ with probability at least
$$1 - rn\,e^{-\Omega(k^3\log(rn))} \ge 1 - e^{-\Omega(k^3\log(rn))}, \qquad (43)$$
completing the proof.
Lemma 4.
Let $\delta_k = c\,k^{3/2}\sqrt{\log(rn)}/\sqrt{m}$ for some absolute constant $c$. Then $\|(GD - C)w^*\|_2 \le \delta_k\|W^*\|_2\,\mathrm{dist}(\hat{B}^t,\hat{B}^*)$ with probability at least $1 - e^{-110k^2\log(rn)}$.

Proof. For ease of notation we drop the superscripts $t$. We define $H := GD - C$ and

$$H_i := G^iD^i - C^i = \hat{B}^\top\Pi^i\hat{B}\hat{B}^\top\hat{B}^* - \hat{B}^\top\Pi^i\hat{B}^* = \hat{B}^\top\Big(\frac1m X_i^\top X_i\Big)(\hat{B}\hat{B}^\top - I_d)\hat{B}^* \tag{44}$$

for all $i\in[rn]$. Then we have

$$\|(GD - C)w^*\|_2^2 = \sum_{i=1}^{rn}\|H_iw_i^*\|_2^2 \le \sum_{i=1}^{rn}\|H_i\|_2^2\|w_i^*\|_2^2 \le \frac{k}{rn}\|W^*\|_2^2\sum_{i=1}^{rn}\|H_i\|_2^2, \tag{45}$$

where the last inequality follows almost surely from Assumption 2 (the 1-row-wise incoherence of $W^*$) and the fact that $krn = \|W^*\|_F^2 \le k\|W^*\|_2^2$, by Assumption 2 and the fact that $W^*$ has rank $k$. It remains to bound $\frac{1}{rn}\sum_{i=1}^{rn}\|H_i\|_2^2$. Although $\|H_i\|_2$ is sub-exponential, as we will show, $\|H_i\|_2^2$ is not sub-exponential, so we cannot directly apply standard concentration results. Instead, we compute a tail bound for each $\|H_i\|_2$ individually, then union bound over $i\in[rn]$. Let $U := \frac{1}{\sqrt{m}}X_i(\hat{B}\hat{B}^\top - I_d)\hat{B}^*$; then the $j$-th row of $U$ is given by $u_j = \frac{1}{\sqrt{m}}\hat{B}^{*\top}(\hat{B}\hat{B}^\top - I_d)x_i^j$, which is $\frac{1}{\sqrt{m}}\hat{B}^{*\top}(\hat{B}\hat{B}^\top - I_d)$-sub-gaussian. Likewise, define $V := \frac{1}{\sqrt{m}}X_i\hat{B}$; then the $j$-th row of $V$ is $v_j = \frac{1}{\sqrt{m}}\hat{B}^\top x_i^j$, which is therefore $\frac{1}{\sqrt{m}}\hat{B}$-sub-gaussian. We leverage the sub-gaussianity of the rows of $U$ and $V$ to make a concentration argument similar to that of Proposition 4.4.5 in [Vershynin, 2018]. First, let $S^{k-1}$ denote the unit sphere in $k$ dimensions, and let $\mathcal{N}_k$ be a $\frac14$-net of $S^{k-1}$ of cardinality $|\mathcal{N}_k| \le 9^k$, which exists by Corollary 4.2.13 in [Vershynin, 2018].
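The net argument rests on two counting facts from [Vershynin, 2018]. First, an $\varepsilon$-net of $S^{k-1}$ with at most $(1 + 2/\varepsilon)^k$ points exists, which gives $9^k$ for $\varepsilon = 1/4$. Second, a union bound over all pairs $(z, y)\in\mathcal{N}_k\times\mathcal{N}_k$ multiplies the per-pair tail probability by $9^{2k}$. The Python sketch below does only this bookkeeping. The function names are hypothetical, and the per-pair exponent (for example $c'm\varepsilon^2$) is passed in as an input rather than derived.

```python
import math

def net_size_bound(eps: float, dim: int) -> float:
    """Upper bound (1 + 2/eps)^dim on the size of an eps-net of the unit sphere S^{dim-1}."""
    return (1.0 + 2.0 / eps) ** dim

def union_bound_log_tail(k: int, per_pair_exponent: float) -> float:
    """Log of 9^{2k} * exp(-per_pair_exponent): a union bound over all pairs (z, y)
    in N_k x N_k, where each fixed pair fails with probability <= exp(-per_pair_exponent)."""
    log_num_pairs = 2 * k * math.log(net_size_bound(0.25, 1))  # log(9^{2k})
    return log_num_pairs - per_pair_exponent

# A 1/4-net of S^{k-1} has at most 9^k points; for k = 3 that is 729.
print(net_size_bound(0.25, 3))  # 729.0
```

For example, with $k = 2$ and a per-pair exponent of 200, `union_bound_log_tail(2, 200.0)` is about $-191.2$. The union bound therefore still leaves a failure probability below $e^{-191}$, which is why the $9^{2k}$ factor is harmless once the exponent scales like $k^2\log(rn)$.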
Next, using equation (4.13) in [Vershynin, 2018], we obtain

$$\Big\|\frac1m(\hat{B}^*)^\top(\hat{B}\hat{B}^\top - I_d)X_i^\top X_i\hat{B}\Big\|_2 = \|U^\top V\|_2 \le 2\max_{z,y\in\mathcal{N}_k}z^\top(U^\top V)y = 2\max_{z,y\in\mathcal{N}_k}z^\top\Big(\sum_{j=1}^m u_jv_j^\top\Big)y = 2\max_{z,y\in\mathcal{N}_k}\sum_{j=1}^m\langle z,u_j\rangle\langle v_j,y\rangle.$$

By definition of sub-gaussianity, $\langle z,u_j\rangle$ and $\langle v_j,y\rangle$ are sub-gaussian with norms $\frac{1}{\sqrt{m}}\|\hat{B}^{*\top}(\hat{B}\hat{B}^\top - I_d)\|_2 = \frac{1}{\sqrt{m}}\mathrm{dist}(\hat{B},\hat{B}^*)$ and $\frac{1}{\sqrt{m}}\|\hat{B}\|_2 = \frac{1}{\sqrt{m}}$, respectively. Thus for all $j\in[m]$, $\langle z,u_j\rangle\langle v_j,y\rangle$ is sub-exponential with norm $\frac{c}{m}\mathrm{dist}(\hat{B},\hat{B}^*)$ for some absolute constant $c$. Note that for any $j\in[m]$ and any $z,y$, $\mathbb{E}[\langle z,u_j\rangle\langle v_j,y\rangle] = z^\top\big((\hat{B}^*)^\top(\hat{B}\hat{B}^\top - I_d)\hat{B}\big)y = 0$. Thus we have a sum of $m$ mean-zero, independent sub-exponential random variables. We can now use Bernstein's inequality to obtain, for any fixed $z,y\in\mathcal{N}_k$,

$$\mathbb{P}\Big(\sum_{j=1}^m\langle z,u_j\rangle\langle v_j,y\rangle \ge s\Big) \le \exp\Big(-c'm\min\Big(\frac{s^2}{\mathrm{dist}^2(\hat{B},\hat{B}^*)},\frac{s}{\mathrm{dist}(\hat{B},\hat{B}^*)}\Big)\Big). \tag{46}$$

Now union bound over all $z,y\in\mathcal{N}_k$ to obtain

$$\mathbb{P}\Big(\frac1m\|(\hat{B}^*)^\top(\hat{B}\hat{B}^\top - I_d)X_i^\top X_i\hat{B}\|_2 \ge 2s\Big) \le 9^{2k}\exp\Big(-c'm\min\big(s^2/\mathrm{dist}^2(\hat{B},\hat{B}^*),\,s/\mathrm{dist}(\hat{B},\hat{B}^*)\big)\Big). \tag{47}$$

Let $s/\mathrm{dist}(\hat{B},\hat{B}^*) = \max(\varepsilon,\varepsilon^2)$ for some $\varepsilon > 0$. Then it follows that $\min\big(s^2/\mathrm{dist}^2(\hat{B},\hat{B}^*),\,s/\mathrm{dist}(\hat{B},\hat{B}^*)\big) = \varepsilon^2$. So we have

$$\mathbb{P}\Big(\frac1m\|(\hat{B}^*)^\top(\hat{B}\hat{B}^\top - I_d)X_i^\top X_i\hat{B}\|_2 \ge 2\,\mathrm{dist}(\hat{B},\hat{B}^*)\max(\varepsilon,\varepsilon^2)\Big) \le 9^{2k}e^{-c'm\varepsilon^2}. \tag{48}$$

Moreover, let $\varepsilon^2 = \frac{c^2k^2\log(rn)}{4m}$ for some constant $c$, and let $m \ge c^2k^2\log(rn)$, so that $\varepsilon \le \frac12$ and $\max(\varepsilon,\varepsilon^2) = \varepsilon$. Then

$$\mathbb{P}\Big(\frac1m\|(\hat{B}^*)^\top(\hat{B}\hat{B}^\top - I_d)X_i^\top X_i\hat{B}\|_2 \ge \mathrm{dist}(\hat{B},\hat{B}^*)\sqrt{\frac{c^2k^2\log(rn)}{m}}\Big) \le 9^{2k}e^{-c'c^2k^2\log(rn)/4} \le e^{-111k^2\log(rn)} \tag{49}$$

for a large enough constant $c$. Thus, noting that $\|H_i\|_2 = \|\frac1m(\hat{B}^*)^\top(\hat{B}\hat{B}^\top - I_d)X_i^\top X_i\hat{B}\|_2$, we obtain

$$\mathbb{P}\Big(\|H_i\|_2^2 \ge \frac{c^2\,\mathrm{dist}^2(\hat{B},\hat{B}^*)\,k^2\log(rn)}{m}\Big) \le e^{-111k^2\log(rn)}. \tag{50}$$

Thus, using (45), we have

$$\begin{aligned}
\mathbb{P}\Big(\|(GD - C)w^*\|_2^2 \ge \frac{c^2\|W^*\|_2^2\,\mathrm{dist}^2(\hat{B},\hat{B}^*)\,k^3\log(rn)}{m}\Big)
&\le \mathbb{P}\Big(\frac{k}{rn}\|W^*\|_2^2\sum_{i=1}^{rn}\|H_i\|_2^2 \ge \frac{c^2\|W^*\|_2^2\,\mathrm{dist}^2(\hat{B},\hat{B}^*)\,k^3\log(rn)}{m}\Big)\\
&= \mathbb{P}\Big(\frac{1}{rn}\sum_{i=1}^{rn}\|H_i\|_2^2 \ge \frac{c^2\,\mathrm{dist}^2(\hat{B},\hat{B}^*)\,k^2\log(rn)}{m}\Big)\\
&\le \sum_{i=1}^{rn}\mathbb{P}\Big(\|H_i\|_2^2 \ge \frac{c^2\,\mathrm{dist}^2(\hat{B},\hat{B}^*)\,k^2\log(rn)}{m}\Big) \le e^{-110k^2\log(rn)},
\end{aligned}$$

completing the proof.

Lemma 5.
Let $\delta_k = c\,k^{3/2}\sqrt{\log(rn)}/\sqrt{m}$. Then

$$\|F\|_F \le \frac{\delta_k}{1-\delta_k}\|W^*\|_2\,\mathrm{dist}(\hat{B}^t,\hat{B}^*) \tag{51}$$

with probability at least $1 - e^{-110k^3\log(rn)} - e^{-110k^2\log(rn)}$.

Proof. Since $\mathrm{vec}(F) = G^{-1}(GD - C)w^*$ (Lemma 2), we have $\|F\|_F \le \|G^{-1}\|_2\|(GD - C)w^*\|_2$. Combining the bound on $\|G^{-1}\|_2$ from Lemma 3 and the bound on $\|(GD - C)w^*\|_2$ from Lemma 4 via a union bound yields the result.

We next focus on showing concentration of the operator $\frac1m\mathcal{A}^\dagger\mathcal{A}$ to the identity operator.

Lemma 6.
Let $\delta'_k = \frac{c\,k\sqrt{d}}{\sqrt{rnm}}$ for some absolute constant $c$. Then for any $t$, if $\delta'_k \le k$,

$$\frac{1}{rn}\Big\|\Big(\frac1m\mathcal{A}^*\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}\Big\|_2 \le \delta'_k\,\mathrm{dist}(\hat{B}^t,\hat{B}^*) \tag{52}$$

with probability at least $1 - e^{-110d} - e^{-110k^2\log(rn)}$.

Proof. We drop superscripts $t$ for simplicity. We first bound the norms of the rows of $Q$ and $W$. Let $q_i\in\mathbb{R}^d$ be the $i$-th row of $Q$ and let $w_i\in\mathbb{R}^k$ be the $i$-th row of $W$. Recall the computation of $W$ from Lemma 2:

$$W = W^*\hat{B}^{*\top}\hat{B} - F \implies w_i^\top = (w_i^*)^\top\hat{B}^{*\top}\hat{B} - f_i^\top.$$

Thus

$$\|q_i\|_2^2 = \|\hat{B}\hat{B}^\top\hat{B}^*w_i^* - \hat{B}f_i - \hat{B}^*w_i^*\|_2^2 = \|(\hat{B}\hat{B}^\top - I_d)\hat{B}^*w_i^* - \hat{B}f_i\|_2^2 \le 2\|(\hat{B}\hat{B}^\top - I_d)\hat{B}^*w_i^*\|_2^2 + 2\|\hat{B}f_i\|_2^2 \le 2\|(\hat{B}\hat{B}^\top - I_d)\hat{B}^*\|_2^2\|w_i^*\|_2^2 + 2\|f_i\|_2^2 = 2k\,\mathrm{dist}^2(\hat{B},\hat{B}^*) + 2\|f_i\|_2^2. \tag{53}$$

Also recall that $\mathrm{vec}(F) = G^{-1}(GD - C)w^*$ from Lemma 2. From equation (38), the $i$-th row of $F$ is given by

$$f_i = (G^i)^{-1}(G^iD^i - C^i)w_i^*.$$

Thus, using the Cauchy-Schwarz inequality and our previous bounds,

$$\|f_i\|_2 \le \|(G^i)^{-1}\|_2\|G^iD^i - C^i\|_2\|w_i^*\|_2 \le \|(G^i)^{-1}\|_2\|G^iD^i - C^i\|_2\sqrt{k}, \tag{54}$$

where (54) follows by Assumption 2. From (50), we have that

$$\mathbb{P}\big(\|G^iD^i - C^i\|_2 \ge \delta_k\,\mathrm{dist}(\hat{B},\hat{B}^*)\big) \le e^{-111k^2\log(rn)}.$$

Similarly, from equations (41) and (42), we have that

$$\mathbb{P}\Big(\|(G^i)^{-1}\|_2 \ge \frac{1}{1-\delta_k}\Big) \le e^{-111k^3\log(rn)}. \tag{55}$$

Now plugging these bounds back into (54) and (53), and assuming $\delta_k \le \frac12$, we obtain

$$\|q_i\|_2^2 \le 2k\,\mathrm{dist}^2(\hat{B},\hat{B}^*)\Big(1 + \frac{\delta_k^2}{(1-\delta_k)^2}\Big) \le 4k\,\mathrm{dist}^2(\hat{B},\hat{B}^*) \tag{56}$$

with probability at least $1 - e^{-110k^2\log(rn)}$.
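The row identity behind (53) is purely algebraic. With $W = W^*\hat{B}^{*\top}\hat{B} - F$ and $Q = W\hat{B}^\top - W^*\hat{B}^{*\top}$, every row satisfies $q_i = (\hat{B}\hat{B}^\top - I_d)\hat{B}^*w_i^* - \hat{B}f_i$, whatever $F$ is. The NumPy sketch below checks this on random instances. The dimensions, the random stand-in for $F$, and the helper names are illustrative assumptions and are not part of the analysis.

```python
import numpy as np

def orthonormal(d: int, k: int, rng: np.random.Generator) -> np.ndarray:
    """Random d-by-k matrix with orthonormal columns (via reduced QR)."""
    q, _ = np.linalg.qr(rng.standard_normal((d, k)))
    return q

def check_row_identity(d=10, k=3, rn=5, seed=0) -> float:
    """Max abs deviation between the rows of Q and (B B^T - I) B* w*_i - B f_i."""
    rng = np.random.default_rng(seed)
    B_hat, B_star = orthonormal(d, k, rng), orthonormal(d, k, rng)
    W_star = rng.standard_normal((rn, k))
    F = 0.1 * rng.standard_normal((rn, k))       # arbitrary stand-in for the error term F
    W = W_star @ B_star.T @ B_hat - F            # form of W from Lemma 2
    Q = W @ B_hat.T - W_star @ B_star.T          # row i of Q is q_i^T
    P = B_hat @ B_hat.T - np.eye(d)
    rhs = (P @ B_star @ W_star.T - B_hat @ F.T).T  # stacks the q_i^T from the identity
    return float(np.max(np.abs(Q - rhs)))

print(check_row_identity() < 1e-12)  # True
```

Because the identity holds for any $F$, the bound in (53) depends on the data only through $\|f_i\|_2$. That is why the proof can control $\|q_i\|_2$ by bounding $\|f_i\|_2$ via (54) and (55).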
Likewise, to upper bound $\|w_i\|_2$ we have

$$\|w_i\|_2^2 \le 2\|\hat{B}^\top\hat{B}^*w_i^*\|_2^2 + 2\|f_i\|_2^2 \le 2\|\hat{B}^\top\hat{B}^*\|_2^2\|w_i^*\|_2^2 + 2\|f_i\|_2^2 \le 2k + \frac{2k\,\delta_k^2}{(1-\delta_k)^2}\mathrm{dist}^2(\hat{B},\hat{B}^*) \tag{57}$$

$$\le 4k, \tag{58}$$

where (57) holds with probability at least $1 - e^{-110k^2\log(rn)}$ conditioning on the same event as in (56), and (58) holds almost surely as long as $\delta_k \le 1/2$. Observe that the matrix $\frac1m\mathcal{A}^*\mathcal{A}(Q) - Q$ can be rewritten as

$$\frac1m\mathcal{A}^*\mathcal{A}(Q) - Q = \frac1m\sum_{i=1}^{rn}\sum_{j=1}^m\langle e_i(x_i^j)^\top, Q\rangle\,e_i(x_i^j)^\top - Q = \frac1m\sum_{i=1}^{rn}\sum_{j=1}^m\langle x_i^j,q_i\rangle\,e_i(x_i^j)^\top - Q. \tag{59}$$

Multiplying the transpose by $\frac{1}{rn}W$ yields

$$\frac{1}{rn}\Big(\frac1m\mathcal{A}^*\mathcal{A}(Q) - Q\Big)^\top W = \frac{1}{rnm}\sum_{i=1}^{rn}\sum_{j=1}^m\Big(\langle x_i^j,q_i\rangle x_i^j(w_i)^\top - q_i(w_i)^\top\Big), \tag{60}$$

where we have used the fact that $Q^\top W = \sum_{i=1}^{rn}q_i(w_i)^\top$. We will argue similarly as in Proposition 4.4.5 in [Vershynin, 2018] to bound the spectral norm of the $d$-by-$k$ matrix on the right-hand side of (60).

First, let $S^{d-1}$ and $S^{k-1}$ denote the unit spheres in $d$ and $k$ dimensions, respectively. Construct $\frac14$-nets $\mathcal{N}_d$ and $\mathcal{N}_k$ over $S^{d-1}$ and $S^{k-1}$, respectively, such that $|\mathcal{N}_d| \le 9^d$ and $|\mathcal{N}_k| \le 9^k$ (which is possible by Corollary 4.2.13 in [Vershynin, 2018]). Then, using equation (4.13) in [Vershynin, 2018], we have

$$\begin{aligned}
\Big\|\frac{1}{rnm}\sum_{i=1}^{rn}\sum_{j=1}^m\Big(\langle x_i^j,q_i\rangle x_i^j(w_i)^\top - q_i(w_i)^\top\Big)\Big\|_2
&\le 2\max_{u\in\mathcal{N}_d,\,v\in\mathcal{N}_k}u^\top\Big(\sum_{i=1}^{rn}\sum_{j=1}^m\frac{1}{rnm}\Big(\langle x_i^j,q_i\rangle x_i^j(w_i)^\top - q_i(w_i)^\top\Big)\Big)v\\
&= 2\max_{u\in\mathcal{N}_d,\,v\in\mathcal{N}_k}\sum_{i=1}^{rn}\sum_{j=1}^m\frac{1}{rnm}\Big(\langle x_i^j,q_i\rangle\langle u,x_i^j\rangle\langle w_i,v\rangle - \langle u,q_i\rangle\langle w_i,v\rangle\Big). \qquad (61)
\end{aligned}$$

By the $I_d$-sub-gaussianity of $x_i^j$, the inner product $\langle u,x_i^j\rangle$ is sub-gaussian with norm at most $c\|u\|_2 = c$ for some absolute constant $c$, for any fixed $u\in\mathcal{N}_d$.
Similarly, $\langle x_i^j,q_i\rangle$ is sub-gaussian with norm at most $c\|q_i\|_2 \le 2c\sqrt{k}\,\mathrm{dist}(\hat{B},\hat{B}^*)$ with probability at least $1 - e^{-110k^2\log(rn)}$, using (56). Further, the sub-exponential norm of the product of two sub-gaussian random variables is at most the product of their sub-gaussian norms (Lemma 2.7.7 in [Vershynin, 2018]). Hence $\langle x_i^j,q_i\rangle\langle u,x_i^j\rangle$ is sub-exponential with norm at most $2c^2\sqrt{k}\,\mathrm{dist}(\hat{B},\hat{B}^*)$. Further, $\frac{1}{rnm}\langle x_i^j,q_i\rangle\langle u,x_i^j\rangle\langle w_i,v\rangle$ is sub-exponential with norm at most

$$\frac{2c^2\sqrt{k}}{rnm}\mathrm{dist}(\hat{B},\hat{B}^*)\,|\langle w_i,v\rangle| \le \frac{2c^2\sqrt{k}}{rnm}\mathrm{dist}(\hat{B},\hat{B}^*)\,\|w_i\|_2 \le \frac{c'k}{rnm}\mathrm{dist}(\hat{B},\hat{B}^*)$$

with probability at least $1 - e^{-110k^2\log(rn)}$, using (58). Finally, note that $\mathbb{E}\big[\frac{1}{rnm}\langle x_i^j,q_i\rangle\langle u,x_i^j\rangle\langle w_i,v\rangle - \frac{1}{rnm}\langle u,q_i\rangle\langle w_i,v\rangle\big] = 0$. Thus, we have a sum of $rnm$ independent, mean-zero sub-exponential random variables, so we apply Bernstein's inequality:

$$\mathbb{P}\Big(\sum_{i=1}^{rn}\sum_{j=1}^m\frac{1}{rnm}\Big(\langle x_i^j,q_i\rangle\langle u,x_i^j\rangle\langle w_i,v\rangle - \langle u,q_i\rangle\langle w_i,v\rangle\Big) \ge s\Big) \le \exp\Big(-c\,rnm\min\Big(\frac{s^2}{k^2\,\mathrm{dist}^2(\hat{B},\hat{B}^*)},\frac{s}{k\,\mathrm{dist}(\hat{B},\hat{B}^*)}\Big)\Big).$$

Union bounding over all $u\in\mathcal{N}_d$ and $v\in\mathcal{N}_k$, and using (61), we obtain

$$\mathbb{P}\Big(\Big\|\frac{1}{rn}\Big(\frac1m\mathcal{A}^*\mathcal{A}(Q) - Q\Big)^\top W\Big\|_2 \ge 2s\Big) \le 9^{d+k}\exp\Big(-c\,rnm\min\Big(\frac{s^2}{k^2\,\mathrm{dist}^2(\hat{B},\hat{B}^*)},\frac{s}{k\,\mathrm{dist}(\hat{B},\hat{B}^*)}\Big)\Big).$$

Let $\frac{s}{k\,\mathrm{dist}(\hat{B},\hat{B}^*)} = \max(\varepsilon,\varepsilon^2)$ for some $\varepsilon > 0$. Then $\varepsilon^2 = \min\Big(\frac{s^2}{k^2\mathrm{dist}^2(\hat{B},\hat{B}^*)},\frac{s}{k\,\mathrm{dist}(\hat{B},\hat{B}^*)}\Big)$. Further, let $\varepsilon^2 = \frac{(110 + \log 9)(d+k)}{c\,rnm}$. As long as $\varepsilon \le 1$, we have $s = k\,\mathrm{dist}(\hat{B},\hat{B}^*)\,\varepsilon$, and since $k \le d$, $2s \le c''k\,\mathrm{dist}(\hat{B},\hat{B}^*)\sqrt{d/(rnm)}$ for an absolute constant $c''$. Therefore

$$\mathbb{P}\Big(\Big\|\frac{1}{rn}\Big(\frac1m\mathcal{A}^*\mathcal{A}(Q) - Q\Big)^\top W\Big\|_2 \ge c''k\,\mathrm{dist}(\hat{B},\hat{B}^*)\sqrt{d/(rnm)}\Big) \le 9^{d+k}e^{-(110+\log 9)(d+k)} = e^{-110(d+k)} \le e^{-110d}.$$

This completes the proof, together with a union bound over the events $\|f_i\|_2 \le \frac{\delta_k}{1-\delta_k}\sqrt{k}\,\mathrm{dist}(\hat{B},\hat{B}^*)$, $i\in[rn]$, used in (56) and (58). Now we are finally ready to show the main result.

B.3 Main Result
Lemma 7.
Define $E_0 := 1 - \mathrm{dist}^2(\hat{B}^0,\hat{B}^*)$, $\bar\sigma_{\max,*} := \max_{\mathcal{I}\subseteq[n],\,|\mathcal{I}|=rn}\sigma_{\max}\big(\tfrac{1}{\sqrt{rn}}W^*_{\mathcal{I}}\big)$ and $\bar\sigma_{\min,*} := \min_{\mathcal{I}\subseteq[n],\,|\mathcal{I}|=rn}\sigma_{\min}\big(\tfrac{1}{\sqrt{rn}}W^*_{\mathcal{I}}\big)$, i.e. the maximum and minimum singular values of any matrix that can be obtained by taking $rn$ rows of $\frac{1}{\sqrt{rn}}W^*$. Suppose that $m \ge c\big(\kappa^4k^3\log(rn)/E_0^2 + \kappa^4k^2d/(E_0^2rn)\big)$ for some absolute constant $c$. Then for any $t$ and any $\eta \le 1/(4\bar\sigma^2_{\max,*})$, we have

$$\mathrm{dist}(\hat{B}^{t+1},\hat{B}^*) \le \big(1 - \eta E_0\bar\sigma^2_{\min,*}/2\big)^{1/2}\,\mathrm{dist}(\hat{B}^t,\hat{B}^*),$$

with probability at least $1 - e^{-100\min(k^2\log(rn),\,d)}$.

Proof. Recall that $W^{t+1}\in\mathbb{R}^{rn\times k}$ and $B^{t+1}\in\mathbb{R}^{d\times k}$ are computed as follows:

$$W^{t+1} = \operatorname*{arg\,min}_{W\in\mathbb{R}^{rn\times k}}\frac{1}{rnm}\big\|\mathcal{A}(W^*\hat{B}^{*\top} - W\hat{B}^{t\top})\big\|_2^2, \tag{62}$$

$$B^{t+1} = \hat{B}^t - \frac{\eta}{rnm}\Big(\mathcal{A}^\dagger\mathcal{A}(W^{t+1}\hat{B}^{t\top} - W^*\hat{B}^{*\top})\Big)^\top W^{t+1}. \tag{63}$$

Let $Q^t = W^{t+1}\hat{B}^{t\top} - W^*\hat{B}^{*\top}$. We have

$$B^{t+1} = \hat{B}^t - \frac{\eta}{rnm}\big(\mathcal{A}^\dagger\mathcal{A}(Q^t)\big)^\top W^{t+1} = \hat{B}^t - \frac{\eta}{rn}Q^{t\top}W^{t+1} - \frac{\eta}{rn}\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}. \tag{64}$$

Now, multiply both sides by $\hat{B}^{*\top}_\perp$. We have

$$\begin{aligned}
\hat{B}^{*\top}_\perp B^{t+1} &= \hat{B}^{*\top}_\perp\hat{B}^t - \frac{\eta}{rn}\hat{B}^{*\top}_\perp Q^{t\top}W^{t+1} - \frac{\eta}{rn}\hat{B}^{*\top}_\perp\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}\\
&= \hat{B}^{*\top}_\perp\hat{B}^t\Big(I_k - \frac{\eta}{rn}W^{t+1\top}W^{t+1}\Big) - \frac{\eta}{rn}\hat{B}^{*\top}_\perp\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}, \qquad (65)
\end{aligned}$$

where the second equality follows because $\hat{B}^{*\top}_\perp Q^{t\top} = \hat{B}^{*\top}_\perp\hat{B}^tW^{t+1\top} - \hat{B}^{*\top}_\perp\hat{B}^*W^{*\top} = \hat{B}^{*\top}_\perp\hat{B}^tW^{t+1\top}$. Then, writing the QR decomposition of $B^{t+1}$ as $B^{t+1} = \hat{B}^{t+1}R^{t+1}$ and multiplying both sides of (65) from the right by $(R^{t+1})^{-1}$ yields

$$\hat{B}^{*\top}_\perp\hat{B}^{t+1} = \Big(\hat{B}^{*\top}_\perp\hat{B}^t\Big(I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}\Big) - \frac{\eta}{rn}\hat{B}^{*\top}_\perp\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}\Big)(R^{t+1})^{-1}. \tag{66}$$

Hence,

$$\begin{aligned}
\mathrm{dist}(\hat{B}^{t+1},\hat{B}^*) &= \Big\|\Big(\hat{B}^{*\top}_\perp\hat{B}^t\Big(I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}\Big) - \frac{\eta}{rn}\hat{B}^{*\top}_\perp\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}\Big)(R^{t+1})^{-1}\Big\|_2\\
&\le \Big\|\hat{B}^{*\top}_\perp\hat{B}^t\Big(I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}\Big)\Big\|_2\|(R^{t+1})^{-1}\|_2 + \frac{\eta}{rn}\Big\|\hat{B}^{*\top}_\perp\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}\Big\|_2\|(R^{t+1})^{-1}\|_2 \qquad (67)\\
&=: A_1 + A_2, \qquad (68)
\end{aligned}$$

where (67) follows by applying the triangle and Cauchy-Schwarz inequalities. We have thus split the upper bound on $\mathrm{dist}(\hat{B}^{t+1},\hat{B}^*)$ into two terms, $A_1$ and $A_2$. The second term, $A_2$, is small due to the concentration of $\frac1m\mathcal{A}^\dagger\mathcal{A}$ to the identity operator, and the first term is strictly smaller than $\mathrm{dist}(\hat{B}^t,\hat{B}^*)$. We start by controlling $A_2$:

$$A_2 = \frac{\eta}{rn}\Big\|\hat{B}^{*\top}_\perp\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}\Big\|_2\|(R^{t+1})^{-1}\|_2 \le \frac{\eta}{rn}\Big\|\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)^\top W^{t+1}\Big\|_2\|(R^{t+1})^{-1}\|_2 \tag{69}$$

$$\le \eta\delta'_k\,\mathrm{dist}(\hat{B}^t,\hat{B}^*)\,\|(R^{t+1})^{-1}\|_2, \tag{70}$$

where (69) follows by Cauchy-Schwarz and the fact that $\hat{B}^*_\perp$ has orthonormal columns, and (70) follows with probability at least $1 - e^{-110d} - e^{-110k^2\log(rn)}$ by Lemma 6. Next we control $A_1$:

$$A_1 = \Big\|\hat{B}^{*\top}_\perp\hat{B}^t\Big(I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}\Big)\Big\|_2\|(R^{t+1})^{-1}\|_2 \le \|\hat{B}^{*\top}_\perp\hat{B}^t\|_2\Big\|I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}\Big\|_2\|(R^{t+1})^{-1}\|_2 = \mathrm{dist}(\hat{B}^t,\hat{B}^*)\Big\|I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}\Big\|_2\|(R^{t+1})^{-1}\|_2. \tag{71}$$

The middle factor gives us contraction, which we bound as follows. First recall that $W^{t+1} = W^*\hat{B}^{*\top}\hat{B}^t - F$, where $F$ is defined in Lemma 2. By Lemma 5, we have that

$$\|F\|_2 \le \frac{\delta_k}{1-\delta_k}\|W^*\|_2\,\mathrm{dist}(\hat{B}^t,\hat{B}^*) \tag{72}$$

with probability at least $1 - e^{-110k^3\log(rn)} - e^{-110k^2\log(rn)}$, which we will use throughout the proof.
Conditioning on this event, we have

$$\lambda_{\max}\big((W^{t+1})^\top W^{t+1}\big) = \|W^*\hat{B}^{*\top}\hat{B}^t - F\|_2^2 \le 2\|W^*\hat{B}^{*\top}\hat{B}^t\|_2^2 + 2\|F\|_2^2 \le 2\|W^*\|_2^2 + \frac{2\delta_k^2}{(1-\delta_k)^2}\|W^*\|_2^2\,\mathrm{dist}^2(\hat{B}^t,\hat{B}^*) \le 4\|W^*\|_2^2, \tag{73}$$

where (73) follows under the assumption that $\delta_k \le 1/2$. Thus, as long as $\eta \le 1/(4\bar\sigma^2_{\max,*})$, the matrix $I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}$ is positive semidefinite, and we have by Weyl's inequality:

$$\begin{aligned}
\Big\|I_k - \frac{\eta}{rn}(W^{t+1})^\top W^{t+1}\Big\|_2 &\le 1 - \frac{\eta}{rn}\lambda_{\min}\big((W^{t+1})^\top W^{t+1}\big) \qquad (74)\\
&= 1 - \frac{\eta}{rn}\lambda_{\min}\Big((W^*\hat{B}^{*\top}\hat{B}^t - F)^\top(W^*\hat{B}^{*\top}\hat{B}^t - F)\Big)\\
&\le 1 - \frac{\eta}{rn}\sigma^2_{\min}(W^*\hat{B}^{*\top}\hat{B}^t) + \frac{2\eta}{rn}\sigma_{\max}(F^\top W^*\hat{B}^{*\top}\hat{B}^t) - \frac{\eta}{rn}\sigma^2_{\min}(F) \qquad (75)\\
&\le 1 - \frac{\eta}{rn}\sigma^2_{\min}(W^*)\,\sigma^2_{\min}(\hat{B}^{*\top}\hat{B}^t) + \frac{2\eta}{rn}\|F\|_2\|W^*\hat{B}^{*\top}\hat{B}^t\|_2 \qquad (76)\\
&\le 1 - \frac{\eta}{rn}\sigma^2_{\min}(W^*)\,\sigma^2_{\min}(\hat{B}^{*\top}\hat{B}^t) + \frac{2\eta}{rn}\frac{\delta_k}{1-\delta_k}\|W^*\|_2^2 \qquad (77)\\
&\le 1 - \eta\bar\sigma^2_{\min,*}\,\sigma^2_{\min}(\hat{B}^{*\top}\hat{B}^t) + 2\eta\frac{\delta_k}{1-\delta_k}\bar\sigma^2_{\max,*}, \qquad (78)
\end{aligned}$$

where (75) follows by again applying Weyl's inequality, under the condition that $2\sigma_{\max}(F^\top W^*\hat{B}^{*\top}\hat{B}^t) \le \sigma^2_{\min}(W^*)\,\sigma^2_{\min}(\hat{B}^{*\top}\hat{B}^t)$, which we will enforce to be true (otherwise we would not have contraction). Also, (76) follows by the Cauchy-Schwarz inequality, and we use Lemma 5 to obtain (77). Lastly, (78) follows by the definitions of $\bar\sigma_{\min,*}$ and $\bar\sigma_{\max,*}$. In order to lower bound $\sigma^2_{\min}(\hat{B}^{*\top}\hat{B}^t)$, note that

$$\sigma^2_{\min}(\hat{B}^{*\top}\hat{B}^t) \ge 1 - \|(\hat{B}^*_\perp)^\top\hat{B}^t\|_2^2 = 1 - \mathrm{dist}^2(\hat{B}^t,\hat{B}^*) \ge 1 - \mathrm{dist}^2(\hat{B}^0,\hat{B}^*) =: E_0. \tag{79}$$

As a result, defining $\bar\delta_k := \delta_k + \delta'_k$ and combining (67), (70), (71), (78), and (79) yields

$$\mathrm{dist}(\hat{B}^{t+1},\hat{B}^*) \le \|(R^{t+1})^{-1}\|_2\Big(1 - \eta\bar\sigma^2_{\min,*}E_0 + 2\eta\frac{\delta_k}{1-\delta_k}\bar\sigma^2_{\max,*} + \eta\delta'_k\Big)\mathrm{dist}(\hat{B}^t,\hat{B}^*) \le \|(R^{t+1})^{-1}\|_2\Big(1 - \eta\bar\sigma^2_{\min,*}E_0 + 2\eta\frac{\bar\delta_k}{1-\bar\delta_k}\bar\sigma^2_{\max,*}\Big)\mathrm{dist}(\hat{B}^t,\hat{B}^*), \tag{80}$$

where (80) follows from the fact that $krn = \|W^*\|_F^2 \le k\|W^*\|_2^2 \implies 1 \le \|W^*\|_2^2/rn \le \bar\sigma^2_{\max,*}$. All that remains to bound is $\|(R^{t+1})^{-1}\|_2$.
Define $S^t := \frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t)$ and observe that

$$\begin{aligned}
(R^{t+1})^\top R^{t+1} = (B^{t+1})^\top B^{t+1} &= \hat{B}^{t\top}\hat{B}^t - \frac{\eta}{rn}\big(\hat{B}^{t\top}S^{t\top}W^{t+1} + (W^{t+1})^\top S^t\hat{B}^t\big) + \frac{\eta^2}{(rn)^2}(W^{t+1})^\top S^tS^{t\top}W^{t+1}\\
&= I_k - \frac{\eta}{rn}\big(\hat{B}^{t\top}S^{t\top}W^{t+1} + (W^{t+1})^\top S^t\hat{B}^t\big) + \frac{\eta^2}{(rn)^2}(W^{t+1})^\top S^tS^{t\top}W^{t+1}; \qquad (81)
\end{aligned}$$

thus, by Weyl's inequality, we have

$$\sigma^2_{\min}(R^{t+1}) \ge 1 - \frac{\eta}{rn}\lambda_{\max}\big(\hat{B}^{t\top}S^{t\top}W^{t+1} + (W^{t+1})^\top S^t\hat{B}^t\big) + \frac{\eta^2}{(rn)^2}\lambda_{\min}\big((W^{t+1})^\top S^tS^{t\top}W^{t+1}\big) \ge 1 - \frac{\eta}{rn}\lambda_{\max}\big(\hat{B}^{t\top}S^{t\top}W^{t+1} + (W^{t+1})^\top S^t\hat{B}^t\big), \tag{82}$$

where (82) follows because $(W^{t+1})^\top S^tS^{t\top}W^{t+1}$ is positive semi-definite. Next, note that

$$\begin{aligned}
\frac{\eta}{rn}\lambda_{\max}\big(\hat{B}^{t\top}S^{t\top}W^{t+1} + (W^{t+1})^\top S^t\hat{B}^t\big)
&= \max_{x:\|x\|_2=1}\frac{\eta}{rn}\Big(x^\top\hat{B}^{t\top}(S^t)^\top W^{t+1}x + x^\top(W^{t+1})^\top S^t\hat{B}^tx\Big)\\
&= \max_{x:\|x\|_2=1}\frac{2\eta}{rn}x^\top(W^{t+1})^\top S^t\hat{B}^tx\\
&= \max_{x:\|x\|_2=1}\Big(\frac{2\eta}{rn}x^\top(W^{t+1})^\top\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)\hat{B}^tx + \frac{2\eta}{rn}x^\top(W^{t+1})^\top Q^t\hat{B}^tx\Big). \qquad (83)
\end{aligned}$$

We first consider the first term. We have

$$\max_{x:\|x\|_2=1}\frac{2\eta}{rn}x^\top(W^{t+1})^\top\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)\hat{B}^tx \le \frac{2\eta}{rn}\Big\|(W^{t+1})^\top\Big(\frac1m\mathcal{A}^\dagger\mathcal{A}(Q^t) - Q^t\Big)\Big\|_2\|\hat{B}^t\|_2 \le 2\eta\delta'_k, \tag{84}$$

where the last inequality follows with probability at least $1 - e^{-110d} - e^{-110k^2\log(rn)}$ from Lemma 6. Next we turn to the second term in (83).
We have

$$\max_{x:\|x\|_2=1}\frac{2\eta}{rn}x^\top(W^{t+1})^\top Q^t\hat{B}^tx = \max_{x:\|x\|_2=1}\frac{2\eta}{rn}\big\langle Q^t, W^{t+1}xx^\top\hat{B}^{t\top}\big\rangle = \max_{x:\|x\|_2=1}\Big(\frac{2\eta}{rn}\big\langle Q^t, W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\big\rangle - \frac{2\eta}{rn}\big\langle Q^t, Fxx^\top\hat{B}^{t\top}\big\rangle\Big). \tag{85}$$

Let $\hat{B}^t_\perp$ be an orthonormal basis of the orthogonal complement of the column space of $\hat{B}^t$, so that $\hat{B}^t\hat{B}^{t\top} - I_d = -\hat{B}^t_\perp\hat{B}^{t\top}_\perp$. For any $x\in\mathbb{R}^k$ with $\|x\|_2 = 1$, we have

$$\begin{aligned}
\frac{2\eta}{rn}\big\langle Q^t, W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\big\rangle
&= \frac{2\eta}{rn}\mathrm{tr}\Big(\big(\hat{B}^t(W^{t+1})^\top - \hat{B}^*(W^*)^\top\big)W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\Big)\\
&= \frac{2\eta}{rn}\mathrm{tr}\Big(\big(\hat{B}^t\hat{B}^{t\top}\hat{B}^*W^{*\top} - \hat{B}^tF^\top - \hat{B}^*W^{*\top}\big)W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\Big)\\
&= \frac{2\eta}{rn}\mathrm{tr}\Big((\hat{B}^t\hat{B}^{t\top} - I_d)\hat{B}^*W^{*\top}W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\Big) - \frac{2\eta}{rn}\mathrm{tr}\Big(\hat{B}^tF^\top W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\Big)\\
&= -\frac{2\eta}{rn}\mathrm{tr}\Big(\hat{B}^{t\top}_\perp\hat{B}^*W^{*\top}W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\hat{B}^t_\perp\Big) - \frac{2\eta}{rn}\mathrm{tr}\Big(F^\top W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\hat{B}^t\Big)\\
&= -\frac{2\eta}{rn}\mathrm{tr}\Big(F^\top W^*\hat{B}^{*\top}\hat{B}^txx^\top\hat{B}^{t\top}\hat{B}^t\Big) \qquad (86)\\
&= -\frac{2\eta}{rn}\mathrm{tr}\Big(F^\top W^*\hat{B}^{*\top}\hat{B}^txx^\top\Big) \qquad (87)\\
&\le \frac{2\eta}{rn}\|F\|_F\|W^*\hat{B}^{*\top}\hat{B}^txx^\top\|_F \qquad (88)\\
&\le \frac{2\eta}{rn}\|F\|_F\|W^*\|_2\|\hat{B}^{*\top}\|_2\|\hat{B}^t\|_2\|xx^\top\|_F \qquad (89)\\
&\le \frac{2\eta}{rn}\|F\|_F\|W^*\|_2 \qquad (90)\\
&\le 2\eta\frac{\delta_k}{1-\delta_k}\bar\sigma^2_{\max,*}, \qquad (91)
\end{aligned}$$

where (86) follows since $\hat{B}^{t\top}\hat{B}^t_\perp = 0$, (87) follows since $\hat{B}^{t\top}\hat{B}^t = I_k$, (88) and (89) follow by the Cauchy-Schwarz inequality, (90) follows by the orthonormality of $\hat{B}^t$ and $\hat{B}^*$ (and $\|xx^\top\|_F = 1$), and (91) follows by Lemma 5 and the definition of $\bar\sigma_{\max,*}$. Next, again for any $x\in\mathbb{R}^k$ with $\|x\|_2 = 1$,

$$\begin{aligned}
-\frac{2\eta}{rn}\big\langle Q^t, Fxx^\top\hat{B}^{t\top}\big\rangle
&= -\frac{2\eta}{rn}\mathrm{tr}\Big(\big(\hat{B}^t\hat{B}^{t\top}\hat{B}^*W^{*\top} - \hat{B}^tF^\top - \hat{B}^*W^{*\top}\big)Fxx^\top\hat{B}^{t\top}\Big)\\
&= -\frac{2\eta}{rn}\mathrm{tr}\Big((\hat{B}^t\hat{B}^{t\top} - I_d)\hat{B}^*W^{*\top}Fxx^\top\hat{B}^{t\top}\Big) + \frac{2\eta}{rn}\mathrm{tr}\Big(\hat{B}^tF^\top Fxx^\top\hat{B}^{t\top}\Big)\\
&= \frac{2\eta}{rn}\mathrm{tr}\Big(\hat{B}^{t\top}_\perp\hat{B}^*W^{*\top}Fxx^\top\hat{B}^{t\top}\hat{B}^t_\perp\Big) + \frac{2\eta}{rn}x^\top F^\top Fx\\
&= \frac{2\eta}{rn}x^\top F^\top Fx \le \frac{2\eta}{rn}\|F\|_2^2 \le 2\eta\frac{\delta_k^2}{(1-\delta_k)^2}\bar\sigma^2_{\max,*}. \qquad (92)
\end{aligned}$$

Thus, we have the following bound on the second term of (83):

$$\max_{x:\|x\|_2=1}\frac{2\eta}{rn}\big\langle Q^t, W^{t+1}xx^\top\hat{B}^{t\top}\big\rangle \le 2\eta\bar\sigma^2_{\max,*}\Big(\frac{\delta_k}{1-\delta_k} + \frac{\delta_k^2}{(1-\delta_k)^2}\Big) \le 4\eta\frac{\delta_k}{(1-\delta_k)^2}\bar\sigma^2_{\max,*}, \tag{93}$$

since $\delta_k \le 1 \implies \delta_k^2 \le \delta_k$. Therefore, using (82), (83), (84) and (93), we have

$$\sigma^2_{\min}(R^{t+1}) \ge 1 - 2\eta\delta'_k - 4\eta\frac{\delta_k}{(1-\delta_k)^2}\bar\sigma^2_{\max,*} \ge 1 - 4\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}, \tag{94}$$

where $\bar\delta_k = \delta'_k + \delta_k$ and the last step uses $\bar\sigma^2_{\max,*} \ge 1$. This means that

$$\|(R^{t+1})^{-1}\|_2 \le \Big(1 - 4\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}\Big)^{-1/2}. \tag{95}$$

Note that $1 - 4\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}$ is strictly positive as long as $\frac{\bar\delta_k}{(1-\bar\delta_k)^2} < 1$ (which we will verify shortly), due to our earlier assumption that $\eta \le 1/(4\bar\sigma^2_{\max,*})$. Therefore, from (80), we have

$$\mathrm{dist}(\hat{B}^{t+1},\hat{B}^*) \le \frac{1}{\sqrt{1 - 4\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}}}\Big(1 - \eta\bar\sigma^2_{\min,*}E_0 + 2\eta\frac{\bar\delta_k}{1-\bar\delta_k}\bar\sigma^2_{\max,*}\Big)\mathrm{dist}(\hat{B}^t,\hat{B}^*).$$

Next, let $\bar\delta_k \le E_0/(25\kappa^2)$. This implies that $\bar\delta_k \le 1/25$. Then $\frac{\bar\delta_k}{(1-\bar\delta_k)^2} < 2\bar\delta_k \le \frac{2E_0}{25\kappa^2} < \frac{E_0}{5\kappa^2} \le 1$, validating (95). Further, it is easily seen that

$$\frac{1 - \eta E_0\bar\sigma^2_{\min,*} + 2\eta\frac{\bar\delta_k}{1-\bar\delta_k}\bar\sigma^2_{\max,*}}{\sqrt{1 - 4\eta\frac{\bar\delta_k}{(1-\bar\delta_k)^2}\bar\sigma^2_{\max,*}}} \le \big(1 - \eta E_0\bar\sigma^2_{\min,*}/2\big)^{1/2},$$

so that $\mathrm{dist}(\hat{B}^{t+1},\hat{B}^*) \le \big(1 - \eta E_0\bar\sigma^2_{\min,*}/2\big)^{1/2}\mathrm{dist}(\hat{B}^t,\hat{B}^*)$. Finally, recall that

$$\bar\delta_k = \delta_k + \delta'_k = c\Big(\frac{k^{3/2}\sqrt{\log(rn)}}{\sqrt{m}} + \frac{k\sqrt{d}}{\sqrt{rnm}}\Big)$$

for some absolute constant $c$. Choosing $m \ge c'\big(\kappa^4k^3\log(rn)/E_0^2 + \kappa^4k^2d/(E_0^2rn)\big)$ for another absolute constant $c'$ satisfies $\bar\delta_k \le E_0/(25\kappa^2)$. Also, we have conditioned on the events described in Lemmas 5 and 6. These occur with probability at least $1 - e^{-110d} - e^{-110k^3\log(rn)} - 2e^{-110k^2\log(rn)} \ge 1 - e^{-100\min(k^2\log(rn),\,d)}$, completing the proof.

Finally, Theorem 1 follows by recursively applying Lemma 7 and taking a union bound over all $t\in[T]$.

Theorem 1.
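Define $E_0 := 1 - \mathrm{dist}^2(\hat{B}^0,\hat{B}^*)$, $\bar\sigma_{\max,*} := \max_{\mathcal{I}\subseteq[n],\,|\mathcal{I}|=rn}\sigma_{\max}\big(\tfrac{1}{\sqrt{rn}}W^*_{\mathcal{I}}\big)$ and $\bar\sigma_{\min,*} := \min_{\mathcal{I}\subseteq[n],\,|\mathcal{I}|=rn}\sigma_{\min}\big(\tfrac{1}{\sqrt{rn}}W^*_{\mathcal{I}}\big)$, i.e. the maximum and minimum singular values of any matrix that can be obtained by taking $rn$ rows of $\frac{1}{\sqrt{rn}}W^*$. Suppose that $m \ge c\big(\kappa^4k^3\log(rn)/E_0^2 + \kappa^4k^2d/(E_0^2rn)\big)$ for some absolute constant $c$. Then for any $T$ and any $\eta \le 1/(4\bar\sigma^2_{\max,*})$, we have

$$\mathrm{dist}(\hat{B}^T,\hat{B}^*) \le \big(1 - \eta E_0\bar\sigma^2_{\min,*}/2\big)^{T/2}\,\mathrm{dist}(\hat{B}^0,\hat{B}^*),$$

with probability at least $1 - Te^{-100\min(k^2\log(rn),\,d)}$.

Further Discussion of Related Works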
The majority of methods for personalized federated learning do not involve dimensionality reduction at the local level. For instance, Jiang et al. [2019] and Fallah et al. [2020] obtain full-dimensional models that are advantageous for local fine-tuning via SGD by applying model-agnostic meta-learning (MAML) methods to federated learning. On the other hand, many works propose to learn multiple full-dimensional models, typically a single global model and unique local models for each client. Hanzely and Richtárik [2020], T Dinh et al. [2020], and Li et al. [2020] add a regularization term to the standard federated learning objective that penalizes local models for being too far from a global model. Huang et al. [2020] take a similar approach but use a regularizer with an attention-inducing function to encode relationships between the local models, similar to [Smith et al., 2017]. Meanwhile, Deng et al. [2020] and Mansour et al. [2020] also use full-dimensional local and global models, and propose to linearly combine them to obtain the model for each client, a framework for which they are able to provide generalization bounds on new client data. Mansour et al. [2020] additionally analyze two other methods for personalization, but these methods require combining client data and hence are not applicable to federated learning.
References
Manoj Ghuhan Arivazhagan, Vinay Aggarwal, Aaditya Kumar Singh, and Sunav Choudhary. Federated learning with personalization layers. arXiv preprint arXiv:1912.00818, 2019.

Yoshua Bengio, Aaron Courville, and Pascal Vincent. Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(8):1798–1828, 2013.

Sebastian Caldas, Sai Meher Karthik Duddu, Peter Wu, Tian Li, Jakub Konečný, H Brendan McMahan, Virginia Smith, and Ameet Talwalkar. LEAF: A benchmark for federated settings. arXiv preprint arXiv:1812.01097, 2018.

Fei Chen, Mi Luo, Zhenhua Dong, Zhenguo Li, and Xiuqiang He. Federated meta-learning with fast convergence and efficient communication. arXiv preprint arXiv:1802.07876, 2018.

Yuejie Chi, Yue M Lu, and Yuxin Chen. Nonconvex optimization meets low-rank matrix factorization: An overview. IEEE Transactions on Signal Processing, 67(20):5239–5269, 2019.

Gregory Cohen, Saeed Afshar, Jonathan Tapson, and Andre Van Schaik. EMNIST: Extending MNIST to handwritten letters. In , pages 2921–2926. IEEE, 2017.

Yuyang Deng, Mohammad Mahdi Kamani, and Mehrdad Mahdavi. Adaptive personalized federated learning. arXiv preprint arXiv:2003.13461, 2020.

Simon S. Du, Wei Hu, Sham M. Kakade, Jason D. Lee, and Qi Lei. Few-shot learning via learning the representation, provably, 2020.

Alireza Fallah, Aryan Mokhtari, and Asuman Ozdaglar. Personalized federated learning: A meta-learning approach, 2020.

Farzin Haddadpour, Mohammad Mahdi Kamani, Aryan Mokhtari, and Mehrdad Mahdavi. Federated learning with compression: Unified analysis and sharp guarantees. arXiv preprint arXiv:2007.01154, 2020.

Filip Hanzely and Peter Richtárik. Federated learning of a mixture of global and local models. arXiv preprint arXiv:2002.05516, 2020.

Yutao Huang, Lingyang Chu, Zirui Zhou, Lanjun Wang, Jiangchuan Liu, Jian Pei, and Yong Zhang. Personalized federated learning: An attentive collaboration approach. arXiv preprint arXiv:2007.03797, 2020.

Prateek Jain, Praneeth Netrapalli, and Sujay Sanghavi. Low-rank matrix completion using alternating minimization. Proceedings of the 45th Annual ACM Symposium on Theory of Computing - STOC '13, 2013.

Yihan Jiang, Jakub Konečný, Keith Rush, and Sreeram Kannan. Improving federated learning personalization via model agnostic meta learning. arXiv preprint arXiv:1909.12488, 2019.

Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian Stich, and Ananda Theertha Suresh. SCAFFOLD: Stochastic controlled averaging for federated learning. In International Conference on Machine Learning, pages 5132–5143. PMLR, 2020.

Mikhail Khodak, Maria-Florina F Balcan, and Ameet S Talwalkar. Adaptive gradient-based meta-learning methods. In Advances in Neural Information Processing Systems, pages 5915–5926, 2019.

Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.

Yann LeCun, Yoshua Bengio, and Geoffrey Hinton. Deep learning. Nature, 521(7553):436–444, 2015.

Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith. Federated optimization in heterogeneous networks. arXiv preprint arXiv:1812.06127, 2018.

Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith. FedDANE: A federated Newton-type method. In , pages 1227–1231. IEEE, 2019.

Tian Li, Shengyuan Hu, Ahmad Beirami, and Virginia Smith. Federated multi-task learning for competing constraints. arXiv preprint arXiv:2012.04221, 2020.

Paul Pu Liang, Terrance Liu, Liu Ziyin, Ruslan Salakhutdinov, and Louis-Philippe Morency. Think locally, act globally: Federated learning with local and global representations. arXiv preprint arXiv:2001.01523, 2020.

Yishay Mansour, Mehryar Mohri, Jae Ro, and Ananda Theertha Suresh. Three approaches for personalization with applications to federated learning. arXiv preprint arXiv:2002.10619, 2020.

Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pages 1273–1282. PMLR, 2017.

Dohyung Park, Anastasios Kyrillidis, Constantine Caramanis, and Sujay Sanghavi. Finding low-rank solutions via nonconvex matrix factorization, efficiently and provably. SIAM Journal on Imaging Sciences, 11(4):2165–2204, 2018.

Reese Pathak and Martin J Wainwright. FedSplit: An algorithmic framework for fast federated optimization. arXiv preprint arXiv:2005.05238, 2020.

Jeffrey Pennington, Richard Socher, and Christopher D Manning. GloVe: Global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 1532–1543, 2014.

Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, and H Brendan McMahan. Adaptive federated optimization. arXiv preprint arXiv:2003.00295, 2020.

Virginia Smith, Chao-Kai Chiang, Maziar Sanjabi, and Ameet S Talwalkar. Federated multi-task learning. In Advances in Neural Information Processing Systems, pages 4424–4434, 2017.

Canh T Dinh, Nguyen Tran, and Tuan Dung Nguyen. Personalized federated learning with Moreau envelopes. Advances in Neural Information Processing Systems, 33, 2020.

Nilesh Tripuraneni, Chi Jin, and Michael I. Jordan. Provable meta-learning of linear representations, 2020.

Stephen Tu, Ross Boczar, Max Simchowitz, Mahdi Soltanolkotabi, and Ben Recht. Low-rank solutions of linear matrix equations via Procrustes flow. In International Conference on Machine Learning, pages 964–973. PMLR, 2016.

Roman Vershynin. High-dimensional probability: An introduction with applications in data science, volume 47. Cambridge University Press, 2018.

Jianyu Wang, Qinghua Liu, Hao Liang, Gauri Joshi, and H Vincent Poor. Tackling the objective inconsistency problem in heterogeneous federated optimization. arXiv preprint arXiv:2007.07481, 2020.

Kangkang Wang, Rajiv Mathews, Chloé Kiddon, Hubert Eichner, Françoise Beaufays, and Daniel Ramage. Federated evaluation of on-device personalization. arXiv preprint arXiv:1910.10252, 2019.

Tao Yu, Eugene Bagdasaryan, and Vitaly Shmatikov. Salvaging federated learning by local adaptation. arXiv preprint arXiv:2002.04758, 2020.

Qinqing Zheng and John Lafferty. Convergence analysis for rectangular matrix completion using Burer-Monteiro factorization and gradient descent. arXiv preprint arXiv:1605.07051, 2016.

Kai Zhong, Prateek Jain, and Inderjit S Dhillon. Efficient matrix sensing using rank-1 gaussian measurements. In