|
## DEEP LEARNING VIA MESSAGE PASSING ALGORITHMS BASED ON BELIEF PROPAGATION
|
|
|
**Anonymous authors** |
|
Paper under double-blind review |
|
|
|
ABSTRACT |
|
|
|
Message-passing algorithms based on the Belief Propagation (BP) equations constitute a well-known distributed computational scheme. They yield exact marginals on tree-like graphical models and have also proven to be effective in many problems defined on loopy graphs, from inference to optimization, from signal processing to clustering. BP-based schemes are fundamentally different from stochastic gradient descent (SGD), on which the current success of deep networks is based. In this paper, we present and adapt to mini-batch training on GPUs a family of BP-based message-passing algorithms with a reinforcement term that biases distributions towards locally entropic solutions. These algorithms are capable of training multi-layer neural networks with performance comparable to SGD heuristics in a diverse set of experiments on natural datasets, including multi-class image classification and continual learning, while yielding improved performance on sparse networks. Furthermore, they allow one to make approximate Bayesian predictions that have higher accuracy than point-wise ones.
|
|
|
1 INTRODUCTION |
|
|
|
Belief Propagation is a method for computing marginals and entropies in probabilistic inference problems (Bethe, 1935; Peierls, 1936; Gallager, 1962; Pearl, 1982). These include optimization problems as well, once they are written as the zero-temperature limit of a Gibbs distribution that uses the cost function as the energy. Learning is one particular case, in which one wants to minimize a data-dependent loss function. These problems are generally intractable, and message-passing techniques have been particularly successful at providing principled approximations through efficient distributed computations.
|
|
|
A particularly compact representation of inference/optimization problems that is used to build message-passing algorithms is provided by factor graphs. A factor graph is a bipartite graph composed of variable nodes and factor nodes expressing the interactions among variables. Belief Propagation is exact on tree-like factor graphs (Yedidia et al., 2003), where the Gibbs distribution is naturally factorized, whereas it is approximate on graphs with loops. Still, loopy BP is routinely used with success in many real-world applications, ranging from error-correcting codes to vision and clustering, just to mention a few. In all these problems, loops are indeed present in the factor graph, and yet the variables are weakly correlated at long range, so BP gives good results. A field in which BP has a long history is the statistical physics of disordered systems, where it is known as the Cavity Method (Mézard et al., 1987). It has been used to study the typical properties of spin glass models, which represent binary variables interacting through random couplings over a given graph. It is well known that in spin glass models defined on complete graphs and on locally tree-like random graphs, which are both loopy, the weak-correlation conditions between variables may hold and BP gives asymptotically exact results (Mézard & Montanari, 2009). Here we will mostly focus on neural networks with ±1 binary weights and sign activation functions, for which the messages and the marginals can be described simply by the difference between the probabilities associated with the +1 and -1 states, the so-called *magnetizations*. The effectiveness of BP for deep learning has never been numerically tested in a systematic way; however, there is clear evidence that the weak correlation decay condition does not hold, and thus BP convergence and approximation quality are unpredictable.
|
|
|
In this paper we explore the effectiveness of a variant of BP that has shown excellent convergence properties in hard optimization problems and in non-convex shallow networks. It goes under the name of focusing BP (fBP) and is based on a probability distribution, a likelihood, that focuses on highly entropic wide minima, neglecting the contribution to marginals from narrow minima even when they are the majority (and hence dominate the Gibbs distribution). This version of BP is thus expected to give good results only in models that have such wide entropic minima as part of their energy landscape. As discussed in (Baldassi et al., 2016a), a simple way to define fBP is to add a "reinforcement" term to the BP equations: an iteration-dependent local field is introduced for each variable, with an intensity proportional to its marginal probability computed in the previous iteration step. This field is gradually increased until the entire system becomes fully biased towards a single configuration. The first version of reinforced BP was introduced in (Braunstein & Zecchina, 2006) as a heuristic algorithm to solve the learning problem in shallow binary networks. Baldassi et al. (2016a) showed that this version of BP is a limiting case of fBP, i.e., of BP equations written for a likelihood that uses the local entropy function instead of the error (energy) loss function. As discussed in depth in that study, one way to introduce a likelihood that focuses on highly entropic regions is to create y coupled replicas of the original system; the fBP equations are obtained as BP equations for the replicated system. It turns out that the fBP equations are identical to the BP equations for the original system, with the only addition of a self-reinforcing term in the message passing scheme. The fBP algorithm can be used as a solver by gradually increasing the effect of the reinforcement: one can control the size of the regions over which the fBP equations estimate the marginals by tuning the parameters that appear in the expression of the reinforcement, until the high entropy regions reduce to a single configuration. Interestingly, by keeping the size of the high entropy region fixed, the fBP fixed point allows one to estimate the marginals and entropy relative to the region.
|
|
|
In this work, we present and adapt to GPU computation a family of fBP-inspired message passing algorithms that are capable of training multi-layer neural networks with generalization performance and computational speed comparable to SGD. This is the first work showing that learning by message passing in deep neural networks 1) is possible and 2) is a viable alternative to SGD. Our version of fBP adds the reinforcement term at each mini-batch step in what we call the Posterior-as-Prior (PasP) rule. Furthermore, using the message-passing algorithm not as a solver but as an estimator of marginals allows us to make locally Bayesian predictions, averaging the predictions over the approximate posterior. The resulting generalization error is significantly better than that of the solver, showing that, although approximate, the marginals of the weights estimated by message passing retain useful information. Consistently with the assumptions underlying fBP, we find that the solutions provided by the message passing algorithms belong to flat entropic regions of the loss landscape and perform well in continual learning tasks and on sparse networks as well.

We also remark that our PasP update scheme is of independent interest and can be combined with different posterior approximation techniques.
|
|
|
The paper is structured as follows: in Sec. 2 we give a brief review of related work. In Sec. 3 we provide a detailed description of the message-passing equations and of the high-level structure of the algorithms. In Sec. 4 we compare the performance of the message passing algorithms against SGD-based approaches in different learning settings.
|
|
|
2 RELATED WORKS |
|
|
|
The literature on message passing algorithms is extensive; we refer to Mézard & Montanari (2009) and Zdeborová & Krzakala (2016) for a general overview. More closely related to our work, multi-layer message-passing algorithms have been developed in inference contexts (Manoel et al., 2017; Fletcher et al., 2018), where they have been shown to produce exact marginals under certain statistical assumptions on (unlearned) weight matrices.
|
|
|
The properties of message passing for learning shallow neural networks have been extensively studied (see Baldassi et al. (2020) and references therein). Barbier et al. (2019) rigorously show that message passing algorithms in generalized linear models perform asymptotically exact inference under some statistical assumptions. Dictionary learning and matrix factorization are harder problems closely related to deep network learning, in particular to the modelling of a single intermediate layer. They have been approached using message passing in Kabashima et al. (2016) and Parker et al. (2014), although the resulting predictions are found to be asymptotically inexact (Maillard et al., 2021). The same problem is faced by the message passing algorithm recently proposed for a multi-layer matrix factorization scenario (Zou et al., 2021). Unfortunately, our framework does not yield asymptotically exact predictions either. Nonetheless, it gives a message passing heuristic that, for the first time, is able to train deep neural networks on natural datasets, and therefore sets a reference for the algorithmic applications of this research line.
|
|
|
A few papers attribute the success of SGD to the geometrical structure (smoothness and flatness) of the loss landscape in neural networks (Baldassi et al., 2015; Chaudhari et al., 2017; Garipov et al., 2018; Li et al., 2018; Pittorino et al., 2021; Feng & Tu, 2021). These considerations do not depend on the particular form of the SGD dynamics and should extend to other types of algorithms as well, although SGD is by far the most popular choice among NN practitioners due to its simplicity, flexibility, speed, and generalization performance.
|
|
|
While our work focuses on message passing schemes, some of the ideas presented here, such as the PasP rule, can be combined with algorithms for training Bayesian neural networks (Hernández-Lobato & Adams, 2015; Wu et al., 2018). Recent work extends BP by combining it with graph neural networks (Kuck et al., 2020; Satorras & Welling, 2021). Finally, some work in computational neuroscience shows similarities to our approach (Rao, 2007).
|
|
|
3 LEARNING BY MESSAGE PASSING |
|
|
|
3.1 POSTERIOR-AS-PRIOR UPDATES |
|
|
|
We consider a multi-layer perceptron with $L$ hidden neuron layers, having weight and bias parameters $\mathcal{W} = \{W^\ell, b^\ell\}_{\ell=0}^{L}$. We allow for stochastic activations $P^\ell(x^{\ell+1} \mid z^\ell)$, where $z^\ell$ is the neuron's pre-activation vector for layer $\ell$, and $P^\ell$ is assumed to be factorized over the neurons. If no stochasticity is present, $P^\ell$ just encodes an element-wise activation function. The probability of output $y$ given an input $x$ is then given by

$$P(y \mid x, \mathcal{W}) = \int \mathrm{d}x^{1:L} \prod_{\ell=0}^{L} P^{\ell+1}\big(x^{\ell+1} \mid W^\ell x^\ell + b^\ell\big), \qquad (1)$$
|
|
|
|
|
where for convenience we defined $x^0 = x$ and $x^{L+1} = y$. In a Bayesian framework, given a training set $\mathcal{D} = \{(x_n, y_n)\}_n$ and a prior distribution over the weights $q_\theta(\mathcal{W})$ in some parametric family, the posterior distribution is given by

$$P(\mathcal{W} \mid \mathcal{D}, \theta) \propto \prod_n P(y_n \mid x_n, \mathcal{W})\, q_\theta(\mathcal{W}). \qquad (2)$$

Here $\propto$ denotes equality up to a normalization factor. Using the posterior, one can compute the Bayesian prediction for a new data-point $x$ through the average $P(y \mid x, \mathcal{D}, \theta) = \int \mathrm{d}\mathcal{W}\, P(y \mid x, \mathcal{W})\, P(\mathcal{W} \mid \mathcal{D}, \theta)$. Unfortunately, the posterior is generically intractable due to the hard-to-compute normalization factor. On the other hand, we are mainly interested in training a distribution that covers wide minima of the loss landscape that generalize well (Baldassi et al., 2016a) and in recovering pointwise estimators within these regions. The Bayesian modeling becomes an auxiliary tool to set the stage for the message passing algorithms seeking flat minima. We also need a formalism that allows for mini-batch training, to speed up the computation and deal with large datasets. Therefore, we devise an update scheme that we call Posterior-as-Prior (PasP), where we evolve the parameters $\theta^t$ of a distribution $q_{\theta^t}(\mathcal{W})$, computed as an approximate mini-batch posterior, in such a way that the outcome of the previous iteration becomes the prior in the following step. In the PasP scheme, $\theta^t$ retains the memory of past observations. We also add an exponential factor $\rho$, which we typically set close to 1, tuning the forgetting rate and playing a role similar to the learning rate in SGD. Given a mini-batch $(X^t, y^t)$ sampled from the training set at time $t$ and a scalar $\rho > 0$, the PasP update reads

$$q_{\theta^{t+1}}(\mathcal{W}) \approx \big[P(\mathcal{W} \mid y^t, X^t, \theta^t)\big]^{\rho}, \qquad (3)$$

where $\approx$ denotes approximate equality and we do not account for the normalization factor. A first approximation may be needed in the computation of the posterior, a second to project the approximate posterior onto the distribution manifold spanned by $\theta$ (Minka, 2001). In practice, we will consider factorized approximate posteriors in an exponential family and priors $q_\theta$ in the same family, although Eq. 3 generically allows for more refined approximations.
|
|
|
|
|
|
|
|
Notice that setting $\rho = 1$, setting the batch-size to 1, and taking a single pass over the dataset, we recover the Assumed Density Filtering algorithm (Minka, 2001). For large enough $\rho$ (including $\rho = 1$), the iterations of $q_{\theta^t}$ will concentrate on a pointwise estimator. This mechanism mimics the reinforcement heuristic commonly used to turn Belief Propagation into a solver for constraint satisfaction problems (Braunstein & Zecchina, 2006) and is related to flat-minima discovery (see focusing-BP in Baldassi et al. (2016a)). A different prior-updating mechanism, which can be understood as empirical Bayes, has been used in Baldassi et al. (2016b).
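To make the outer loop concrete, the following is a minimal Julia sketch of the PasP iteration of Eq. 3, assuming the factorized binary-weight parameterization of Appendix A.1 (natural parameters $\theta$, so that raising the posterior to the power $\rho$ simply rescales them); `infer` is a hypothetical placeholder standing for the inner message passing loop of Sec. 3.2, not part of our released interface.

```julia
# Sketch of the Posterior-as-Prior (PasP) outer loop of Eq. 3.
# θ collects the natural parameters of the factorized prior over ±1 weights,
# so q_θ(W) ∝ exp.(θ .* W).  `infer(θ, X, y)` must return the natural
# parameters of the approximate mini-batch posterior (e.g. from Algorithm 1).
function pasp_train!(θ, batches, infer; ρ = 1.0, epochs = 1)
    for _ in 1:epochs, (X, y) in batches
        θ_post = infer(θ, X, y)   # approximate posterior of Eq. 2 on this mini-batch
        θ .= ρ .* θ_post          # posterior-as-prior update: new prior = posterior^ρ
    end
    return θ
end
```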
|
|
|
3.2 INNER MESSAGE PASSING LOOP |
|
|
|
While the PasP rule takes care of the reinforcement heuristic across mini-batches, we compute the mini-batch posterior in Eq. 3 using message passing approaches derived from Belief Propagation. BP is an iterative scheme for computing marginals and entropies of statistical models (Mézard & Montanari, 2009). It is most conveniently expressed on factor graphs, that is, bipartite graphs where the two sets of nodes are called variable nodes and factor nodes; they respectively represent the variables involved in the statistical model and their interactions. Messages from factor nodes to variable nodes and vice versa are exchanged along the edges of the factor graph for a certain number of BP iterations or until a fixed point is reached.
|
|
|
The factor graph for $P(\mathcal{W} \mid X^t, y^t, \theta^t)$ can be derived from Eq. 2, with the following additional specifications. For simplicity, we will ignore the bias term in each layer. We assume a factorized $q_{\theta^t}(\mathcal{W})$, with each factor parameterized by its first two moments. In what follows, we drop the PasP iteration index $t$. For each example $(x_n, y_n)$ in the mini-batch, we introduce the auxiliary variables $x^\ell_n$, $\ell = 1, \dots, L$, representing the layers' activations. For each example, each neuron in the network contributes a factor node to the factor graph. The scalar components of the weight matrices and the activation vectors become variable nodes. This construction is presented in Appendix A, where we also derive the message update rules on the factor graph. The factor graph thus defined is extremely loopy, and straightforward iteration of BP has convergence issues. Moreover, in the presence of a homogeneous prior over the weights, the neuron permutation symmetry in each hidden layer induces a strongly attractive symmetric fixed point that hinders learning. We work around these issues by breaking the symmetry at time $t = 0$ with an inhomogeneous prior. In our experiments, a little initial heterogeneity is sufficient to obtain specialized neurons at each following time step. Additionally, we do not require message passing convergence in the inner loop (see Algorithm 1), but perform one or a few iterations for each $\theta$ update. We also include an inertia term, commonly called a damping factor, in the message updates (see B.2). As we shall discuss, these simple rules suffice to train deep networks by message passing.
|
|
|
For the inner loop we adapt to deep neural networks four different message passing algorithms, all well known in the literature although derived in simpler settings: Belief Propagation (BP), BP-Inspired (BPI) message passing, mean-field (MF), and approximate message passing (AMP). The last three algorithms can be considered approximations of the first one. In the following paragraphs we discuss their common traits, present the BP updates as an example, and refer to Appendix A for an in-depth exposition. For all algorithms, message updates can be divided into a forward pass and a backward pass, as also done in (Fletcher et al., 2018) in a multi-layer inference setting. The BP algorithm is compactly reported in Algorithm 1.
|
|
|
**Meaning of messages.** All the messages involved in the message passing can be understood in terms of cavity marginals or full marginals (as mentioned in the introduction, BP is also known as the Cavity Method, see (Mézard & Montanari, 2009)). Of particular relevance are $m^\ell_{ki}$ and $\sigma^\ell_{ki}$, denoting the mean and variance of the weight $W^\ell_{ki}$. The quantities $\hat{x}^\ell_{in}$ and $\Delta^\ell_{in}$ instead denote the mean and variance of the $i$-th neuron's activation in layer $\ell$ for a given input $x_n$.
|
|
|
**Scalar free energies.** All message passing schemes are conveniently expressed in terms of two functions that correspond to the effective free energy (Zdeborová & Krzakala, 2016) of a single neuron and of a single weight respectively:

$$\phi^\ell(B, A, \omega, V) = \log \int \mathrm{d}x\, \mathrm{d}z\; e^{-\frac{1}{2}Ax^2 + Bx}\, P^\ell(x \mid z)\, e^{-\frac{(\omega - z)^2}{2V}}, \qquad \ell = 1, \dots, L \qquad (4)$$

$$\psi(H, G, \theta) = \log \int \mathrm{d}w\; e^{-\frac{1}{2}Gw^2 + Hw}\, q_\theta(w). \qquad (5)$$

Notice that for common deterministic activations such as ReLU and sign, the function $\phi$ has analytic and smooth expressions (see Appendix A.8). The same holds for the function $\psi$ when $q_\theta(w)$ is Gaussian (continuous weights) or a mixture of atoms (discrete weights). At the last layer we impose $P^{L+1}(y \mid z) = \mathbb{I}(y = \mathrm{sign}(z))$ in binary classification tasks and $P^{L+1}(y \mid z) = \mathbb{I}(y = \arg\max(z))$ in multi-class classification (see Appendix A.9). While in our experiments we use hard constraints for the final output, therefore solving a constraint satisfaction problem, it would be interesting to also consider soft constraints and introduce a temperature, but this is beyond the scope of our work.
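For the binary ±1 weights used in most of our experiments, $\psi$ takes the closed form of Eq. 23 in Appendix A.1; a minimal Julia sketch of the resulting weight statistics (the derivatives entering Eqs. 8, 9) is given below.

```julia
# Scalar weight free energy ψ for binary ±1 weights (Eq. 5 with q_θ(w) ∝ exp(θw));
# the Gw² term is a constant and drops out (see Appendix A.1, Eq. 23).
ψ_binary(H, θ) = log(2 * cosh(H + θ))
m_binary(H, θ) = tanh(H + θ)          # ∂_H ψ : weight mean, as in Eq. (8)
σ_binary(H, θ) = 1 - tanh(H + θ)^2    # ∂²_H ψ : weight variance, as in Eq. (9)
```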
|
|
|
**Start and end of message passing.** At the beginning of a new PasP iteration $t$, we reset the messages (see Appendix A) and run message passing for $\tau_{\max}$ iterations. We then compute the new prior's parameters $\theta^{t+1}$ from the posterior given by the message passing.
|
|
|
**BP Forward pass.** After initialization of the messages at time $\tau = 0$, for each following time we propagate a set of messages from the first to the last layer and then another set from the last to the first. For an intermediate layer $\ell$ the forward pass reads

$$\hat{x}^{\ell,\tau}_{in \to k} = \partial_B \phi^\ell\big(B^{\ell,\tau-1}_{in \to k},\, A^{\ell,\tau-1}_{in},\, \omega^{\ell-1,\tau}_{in},\, V^{\ell-1,\tau}_{in}\big) \qquad (6)$$

$$\Delta^{\ell,\tau}_{in} = \partial^2_B \phi^\ell\big(B^{\ell,\tau-1}_{in},\, A^{\ell,\tau-1}_{in},\, \omega^{\ell-1,\tau}_{in},\, V^{\ell-1,\tau}_{in}\big) \qquad (7)$$

$$m^{\ell,\tau}_{ki \to n} = \partial_H \psi\big(H^{\ell,\tau-1}_{ki \to n},\, G^{\ell,\tau-1}_{ki},\, \theta^\ell_{ki}\big) \qquad (8)$$

$$\sigma^{\ell,\tau}_{ki} = \partial^2_H \psi\big(H^{\ell,\tau-1}_{ki},\, G^{\ell,\tau-1}_{ki},\, \theta^\ell_{ki}\big) \qquad (9)$$

$$V^{\ell,\tau}_{kn} = \sum_i \Big[ \big(m^{\ell,\tau}_{ki \to n}\big)^2\, \Delta^{\ell,\tau}_{in} + \sigma^{\ell,\tau}_{ki}\, \big(\hat{x}^{\ell,\tau}_{in \to k}\big)^2 + \sigma^{\ell,\tau}_{ki}\, \Delta^{\ell,\tau}_{in} \Big] \qquad (10)$$

$$\omega^{\ell,\tau}_{kn \to i} = \sum_{i' \neq i} m^{\ell,\tau}_{ki' \to n}\, \hat{x}^{\ell,\tau}_{i'n \to k} \qquad (11)$$

The equations for the first layer differ slightly and in an intuitive way from the ones above (see Appendix A.3).
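As an illustration of how these updates translate into tensor operations, the following is a minimal Julia sketch of the forward moment propagation of Eqs. (10, 11), with the cavity ($i' \neq i$) corrections dropped for readability, so it is closer to the MF variant of Appendix A.5 than to full BP.

```julia
# Forward moment propagation for one layer, without cavity corrections.
# m, σ : weight means and variances, size K×N.
# xhat, Δ : activation means and variances from the layer below, size N×B.
function forward_moments(m, σ, xhat, Δ)
    ω = m * xhat                                        # Eq. (11) without the i'≠i restriction
    V = (m .^ 2) * Δ .+ σ * (xhat .^ 2) .+ σ * Δ        # Eq. (10)
    return ω, V                                         # pre-activation means/variances, K×B
end
```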
|
|
|
**BP Backward pass.** The backward pass updates a set of messages from the last to the first layer:

$$g^{\ell,\tau}_{kn \to i} = \partial_\omega \phi^{\ell+1}\big(B^{\ell+1,\tau}_{kn},\, A^{\ell+1,\tau}_{kn},\, \omega^{\ell,\tau}_{kn \to i},\, V^{\ell,\tau}_{kn}\big) \qquad (12)$$

$$\Gamma^{\ell,\tau}_{kn} = -\partial^2_\omega \phi^{\ell+1}\big(B^{\ell+1,\tau}_{kn},\, A^{\ell+1,\tau}_{kn},\, \omega^{\ell,\tau}_{kn},\, V^{\ell,\tau}_{kn}\big) \qquad (13)$$

$$A^{\ell,\tau}_{in} = \sum_k \Big[ \Big(\big(m^{\ell,\tau}_{ki \to n}\big)^2 + \sigma^{\ell,\tau}_{ki}\Big)\, \Gamma^{\ell,\tau}_{kn} - \sigma^{\ell,\tau}_{ki}\, \big(g^{\ell,\tau}_{kn \to i}\big)^2 \Big] \qquad (14)$$

$$B^{\ell,\tau}_{in \to k} = \sum_{k' \neq k} m^{\ell,\tau}_{k'i \to n}\, g^{\ell,\tau}_{k'n \to i} \qquad (15)$$

$$G^{\ell,\tau}_{ki} = \sum_n \Big[ \Big(\big(\hat{x}^{\ell,\tau}_{in \to k}\big)^2 + \Delta^{\ell,\tau}_{in}\Big)\, \Gamma^{\ell,\tau}_{kn} - \Delta^{\ell,\tau}_{in}\, \big(g^{\ell,\tau}_{kn \to i}\big)^2 \Big] \qquad (16)$$

$$H^{\ell,\tau}_{ki \to n} = \sum_{n' \neq n} \hat{x}^{\ell,\tau}_{in' \to k}\, g^{\ell,\tau}_{kn' \to i} \qquad (17)$$

As with the forward pass, we add the caveat that for the last layer the equations are slightly different from the ones above.
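A companion Julia sketch for the backward pass (Eqs. 14–17) under the same simplification (no cavity corrections); the derivatives $g$, $\Gamma$ of $\phi$ are assumed to be computed beforehand from Eqs. (12, 13).

```julia
# Backward moment propagation for one layer, without cavity corrections.
# g, Γ : first/second ω-derivatives of ϕ from the layer above, size K×B.
# m, σ : weight means and variances, K×N; xhat, Δ : activation moments, N×B.
function backward_moments(m, σ, g, Γ, xhat, Δ)
    A = (m .^ 2 .+ σ)' * Γ .- σ' * (g .^ 2)          # Eq. (14), size N×B
    B = m' * g                                        # Eq. (15) without the k'≠k restriction
    G = Γ * (xhat .^ 2 .+ Δ)' .- (g .^ 2) * Δ'        # Eq. (16), size K×N (note: K×B times B×N)
    H = g * xhat'                                     # Eq. (17) without the n'≠n restriction
    return A, B, G, H
end
```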
|
|
|
|
|
|
|
|
**Computational complexity.** The message passing equations boil down to element-wise operations and tensor contractions that we easily implement using the GPU-friendly Julia library Tullio.jl (Abbott et al., 2021). For a layer of input and output size N and a batch-size of B, the time complexity of a forward-and-backward iteration is O(N²B) for all message passing algorithms (BP, BPI, MF, and AMP), the same as SGD. The prefactor varies and is generally larger than SGD's (see Appendix B.9). Also, the time complexity of message passing is proportional to τmax (which we typically set to 1). We provide our implementation in the GitHub repo anonymous.
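For reference, a contraction such as Eq. (11) (without the cavity term) maps to a single einsum-style Tullio statement; the sizes below are arbitrary placeholders, not the settings used in our experiments.

```julia
using Tullio

# Example of the kind of O(K·N·B) tensor contraction behind the message updates.
m    = randn(128, 784)   # weight means, K×N
xhat = randn(784, 64)    # activation means, N×B
@tullio ω[k, b] := m[k, i] * xhat[i, b]   # ω[k,b] = Σ_i m[k,i]·xhat[i,b]
```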
|
|
|
**Algorithm 1: BP for deep neural networks**

// Message passing used in the PasP Eq. 3 to approximate the mini-batch posterior.
// Here we specifically refer to the BP updates. BPI, MF, and AMP updates take the same form,
// but use the rules in Appendix A.4, A.5, and A.7 respectively.

**1** Initialize messages.
**2** **for** τ = 1, . . ., τmax **do**
// Forward pass
**3**   **for** ℓ = 0, . . ., L **do**
**4**     compute x̂^ℓ, Δ^ℓ using (6, 7)
**5**     compute m^ℓ, σ^ℓ using (8, 9)
**6**     compute V^ℓ, ω^ℓ using (10, 11)
// Backward pass
**7**   **for** ℓ = L, . . ., 0 **do**
**8**     compute g^ℓ, Γ^ℓ using (12, 13)
**9**     compute A^ℓ, B^ℓ using (14, 15)
**10**    compute G^ℓ, H^ℓ using (16, 17)
|
|
|
4 NUMERICAL RESULTS |
|
|
|
We implement our message passing algorithms on neural networks with continuous and binary weights and with binary activations. In our experiments we fix τmax = 1; we typically do not observe an increase in performance when taking more steps, except in some specific cases and in particular for MF layers. We remark that for τmax = 1 the BP and BPI equations are identical, so in most of the subsequent numerical results we only investigate BP.
|
|
|
We compare our algorithms with an SGD-based algorithm adapted to binary architectures (Hubara et al., 2016), which we call BinaryNet throughout the paper (see Appendix B.6 for details). Bayesian predictions are compared with the gradient-based Expectation Backpropagation (EBP) algorithm (Soudry et al., 2014a), which is also able to deal with discrete weights and activations. In all architectures we avoid the use of bias terms and batch-normalization layers.
|
|
|
We find that message-passing algorithms are able to train generic MLP architectures with varying numbers and sizes of hidden layers. As for the datasets, we are able to perform both binary classification and multi-class classification on standard computer vision datasets such as MNIST, Fashion-MNIST, and CIFAR-10. Since these datasets consist of 10 classes, for the binary classification task we divide each dataset into two classes (even vs. odd).
|
|
|
We report that message passing algorithms are able to solve these optimization problems with generalization performance comparable to or better than SGD-based algorithms. Some of the message passing algorithms (BP and AMP in particular) need fewer epochs to achieve low error than SGD-based algorithms, even when adaptive methods like Adam are considered. The timings of our GPU implementations of message passing algorithms are competitive with SGD (see Appendix B.9).
|
|
|
|
|
|
4.1 EXPERIMENTS ACROSS ARCHITECTURES |
|
|
|
We select a specific task, multi-class classification on Fashion-MNIST, and we compare the message passing algorithms with BinaryNet for different choices of the architecture (i.e., we vary the number and the size of the hidden layers). In Fig. 1 (Left) we present the learning curves for an MLP with 3 hidden layers of 501 units with binary weights and activations. Similar results hold in our experiments with 2 or 3 hidden layers of 101, 501 or 1001 units and with batch sizes from 1 to 1024. The parameters used in our simulations are reported in Appendix B.3. Results on networks with continuous weights can be found in Fig. 2 (Right).
|
|
|
4.2 SPARSE LAYERS |
|
|
|
Since the BP algorithm has notoriously been successful on sparse graphs, we test a straightforward implementation of pruning at initialization, i.e., we impose a random boolean mask on the weights that we keep fixed throughout training (a minimal sketch is given below). We call sparsity the fraction of zeroed weights. This kind of non-adaptive pruning is known to largely hinder learning (Frankle et al., 2021; Sung et al., 2021). In the right panel of Fig. 1, we report results on sparse binary networks in which we train an MLP with 2 hidden layers of 101 units on the MNIST dataset. For reference, results on pruning quantized/binary networks can be found in Refs. (Han et al., 2016; Ardakani et al., 2017; Tung & Mori, 2018; Diffenderfer & Kailkhura, 2021). Experimenting with sparsity up to 90%, we observe that BP and MF perform better than BinaryNet, while AMP lags behind BinaryNet.
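A minimal Julia sketch of this fixed pruning-at-initialization mask (the layer sizes are placeholders, not the experimental settings):

```julia
# Pruning at initialization: draw a random boolean mask once at the desired
# sparsity and keep it fixed throughout training; masked weights are simply
# excluded from the message passing updates.
sparsity = 0.9
K, N = 101, 784                  # hypothetical layer size
mask = rand(K, N) .> sparsity    # true = weight kept (≈ 10% of entries here)
```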
|
|
|
|
Figure 1: (Left) Training curves of message passing algorithms compared with BinaryNet on the Fashion-MNIST dataset (multi-class classification) with a binary MLP with 3 hidden layers of 501 units. (Right) Final test accuracy when varying the layers' sparsity in a binary MLP with 2 hidden layers of 101 units on the MNIST dataset (multi-class). In both panels the batch-size is 128 and curves are averaged over 5 realizations of the initial conditions (and of the sparsity pattern in the right panel).
|
|
|
4.3 EXPERIMENTS ACROSS DATASETS |
|
|
|
We now fix the architecture, an MLP with 2 hidden layers of 501 neurons each with binary weights and activations. We vary the dataset, i.e., we test the BP-based algorithms on standard computer vision benchmark datasets such as MNIST, Fashion-MNIST and CIFAR-10, in both the multi-class and binary classification tasks. In Tab. 1 we report the final test errors obtained by the message passing algorithms compared to the BinaryNet baseline. See Appendix B.4 for the corresponding training errors and the parameters used in the simulations. We mention that while the test performance is mostly comparable, the train error tends to be lower for the message passing algorithms.
|
|
|
|
|
|
|
|
| Dataset | BinaryNet | BP | AMP | MF |
|---|---|---|---|---|
| MNIST (2 classes) | 1.3 ± 0.1 | 1.4 ± 0.2 | 1.4 ± 0.1 | 1.3 ± 0. |
| Fashion-MNIST (2 classes) | 2.4 ± 0.1 | 2.3 ± 0.1 | 2.4 ± 0.1 | 2.3 ± 0. |
| CIFAR-10 (2 classes) | 30.0 ± 0.3 | 31.4 ± 0.1 | 31.1 ± 0.3 | 31.1 ± 0. |
| MNIST | 2.2 ± 0.1 | 2.6 ± 0.1 | 2.6 ± 0.1 | 2.3 ± 0. |
| Fashion-MNIST | 12.0 ± 0.6 | 11.8 ± 0.3 | 11.9 ± 0.2 | 12.1 ± 0. |
| CIFAR-10 | 59.0 ± 0.7 | 58.7 ± 0.3 | 58.5 ± 0.2 | 60.4 ± 1. |

Table 1: Test error (%) on MNIST, Fashion-MNIST and CIFAR-10 (binary and multi-class classification) of various algorithms on an MLP with 2 hidden layers of 501 units with binary weights and activations. All algorithms are trained with batch-size 128; standard deviations are calculated over 5 random initializations.

4.4 BAYESIAN ERROR

The message passing framework, used as an estimator of the mini-batch posterior marginals, allows us to compute an approximate Bayesian prediction, i.e., to average the pointwise predictions over the approximate posterior. We observe better generalization error from Bayesian predictions than from pointwise ones, showing that the marginals retain useful information. However, these are the marginals obtained with the PasP mini-batch procedure (the exact ones should be computed on the full posterior, but this converges with difficulty in our tests). Since BP-based algorithms find solutions in flat regions (as also confirmed by the local energy measure performed in Appendix B.5), the Bayesian error we compute can be considered as a local approximation of the full one. We report results for binary classification on the MNIST dataset in Fig. 2, and we observe the same behavior on other datasets and architectures. We obtain the Bayesian prediction from the estimated marginals with a single forward pass of the message passing. To obtain good Bayesian predictions, the posterior distribution should not concentrate too much; otherwise the Bayesian prediction reduces to the prediction of a single configuration. In Fig. 2 we also report a comparison of BP (point-wise and Bayesian) with SGD and with an algorithm able to make Bayesian predictions, Expectation Backpropagation (Soudry et al., 2014a); see Appendix B.7 for implementation details.
|
|
|
|
|
Figure 2: Test error curves for Bayesian and point-wise predictions for an MLP with 2 hidden layers of 101 units on the 2-class MNIST dataset, with (Left) binary and (Right) continuous weights. In both cases, we compare SGD, BP (point-wise and Bayesian) and EBP (point-wise and Bayesian). See Appendix B.3 for details.
|
|
|
4.5 CONTINUAL LEARNING |
|
|
|
|
|
Given the high local entropy (i.e., the flatness) of the solutions found by the BP-based algorithms (see Appendix B.5), we perform additional tests in a classic setting, continual learning, where the possibility of locally rearranging the solutions while keeping the training error low can be an advantage. When a deep network is trained sequentially on different tasks, it tends to forget previously seen tasks exponentially fast while learning new ones (McCloskey & Cohen, 1989; Robins, 1995; Fusi et al., 2005). Recent work (Feng & Tu, 2021) has shown that searching for a flat region in the loss landscape can indeed help to prevent catastrophic forgetting. Several heuristics have been proposed to mitigate the problem (Kirkpatrick et al., 2017; Aljundi et al., 2018; Zenke et al., 2017; Laborieux et al., 2021), but all require specialized adjustments to the loss or to the dynamics.
|
|
|
Here we show instead that our message passing schemes are naturally suited to learning multiple tasks sequentially, mitigating the characteristic memory issues of gradient-based schemes without the need for explicit modifications. As a prototypical experiment, we sequentially train a multi-layer neural network on 6 different versions of the MNIST dataset, where the pixels of the images have been randomly permuted (Goodfellow et al., 2013), giving a fixed budget of 40 epochs on each task. We present the results for a two-hidden-layer neural network with 2001 units on each layer (see Appendix B.3 for details). As can be seen in Fig. 3, at the end of training the BP algorithm is able to reach good generalization performance on all the tasks. We compare the BP performance with BinaryNet, which already performs better than SGD with continuous weights (see the discussion in Laborieux et al. (2021)). While our BP implementation is not competitive with ad-hoc techniques specifically designed for this problem, it beats non-specialized heuristics. Moreover, we believe that specialized approaches like the one of Laborieux et al. (2021) can be adapted to message passing as well.
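For reference, a minimal Julia sketch of the permuted-MNIST task construction used in this experiment (the data loading itself is omitted; `X` is assumed to be a matrix of flattened images):

```julia
using Random

# Permuted-MNIST protocol (Goodfellow et al., 2013): each task applies one
# fixed random pixel permutation to all images.  X is 784 × n_samples.
make_tasks(X, ntasks) = [X[randperm(size(X, 1)), :] for _ in 1:ntasks]
```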
|
|
|
|
|
|
Figure 3: Performance of BP and BinaryNet on the permuted MNIST task (see text) for a two-hidden-layer network with 2001 units on each layer and binary weights and activations. The model is trained sequentially on 6 different versions of the MNIST dataset (the tasks), where the pixels have been permuted. (Left) Test accuracy on each task after the network has been trained on all the tasks. (Right) Test accuracy on the first task as a function of the number of epochs. Points are averages over 5 independent runs; shaded areas are errors on the mean.
|
|
|
|
|
5 DISCUSSION AND CONCLUSIONS |
|
|
|
While successful in many fields, message passing algorithms have notoriously struggled to scale to deep neural network training problems. Here we have developed a class of fBP-based message passing algorithms and used them within an update scheme, Posterior-as-Prior (PasP), that makes it possible to train deep and wide multilayer perceptrons by message passing.
|
|
|
We performed experiments with binary activations and either binary or continuous weights. Future work should try to include different activations, biases, batch-normalization, and convolutional layers as well. Another interesting direction is the algorithmic computation of the (local) entropy of the model from the messages.
|
|
|
Further theoretical work is needed for a more complete understanding of the robustness of our methods. Recent developments in message passing algorithms (Rangan et al., 2019) and related theoretical analyses (Goldt et al., 2020) could provide fruitful inspiration. While our algorithms can be used for approximate Bayesian inference, exact posterior calculation is still out of reach for message passing approaches, and much technical work is needed in that direction.
|
|
|
|
|
|
|
|
REFERENCES |
|
|
|
Michael Abbott, Dilum Aluthge, N3N5, Simeon Schaub, Carlo Lucibello, Chris Elrod, and Johnny |
|
[Chen. Tullio.jl julia package, 2021. URL https://github.com/mcabbott/Tullio.jl.](https://github.com/mcabbott/Tullio.jl) |
|
|
|
Rahaf Aljundi, Francesca Babiloni, Mohamed Elhoseiny, Marcus Rohrbach, and Tinne Tuytelaars. |
|
Memory aware synapses: Learning what (not) to forget. In Proceedings of the European Conference |
|
_on Computer Vision (ECCV), pp. 139–154, 2018._ |
|
|
|
Arash Ardakani, Carlo Condo, and Warren J. Gross. Sparsely-connected neural networks: Towards efficient VLSI implementation of deep neural networks. In 5th International Conference on Learning |
|
_Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings._ |
|
[OpenReview.net, 2017. URL https://openreview.net/forum?id=r1fYuytex.](https://openreview.net/forum?id=r1fYuytex) |
|
|
|
Carlo Baldassi, Alfredo Braunstein, Nicolas Brunel, and Riccardo Zecchina. Efficient supervised |
|
learning in networks with binary synapses. Proceedings of the National Academy of Sciences, |
|
[104(26):11079–11084, 2007. ISSN 0027-8424. doi: 10.1073/pnas.0700324104. URL https:](https://www.pnas.org/content/104/26/11079) |
|
[//www.pnas.org/content/104/26/11079.](https://www.pnas.org/content/104/26/11079) |
|
|
|
Carlo Baldassi, Alessandro Ingrosso, Carlo Lucibello, Luca Saglietti, and Riccardo Zecchina. |
|
Subdominant dense clusters allow for simple learning and high computational performance |
|
in neural networks with discrete synapses. _Phys. Rev. Lett., 115:128101, Sep 2015._ |
|
[doi: 10.1103/PhysRevLett.115.128101. URL https://link.aps.org/doi/10.1103/](https://link.aps.org/doi/10.1103/PhysRevLett.115.128101) |
|
[PhysRevLett.115.128101.](https://link.aps.org/doi/10.1103/PhysRevLett.115.128101) |
|
|
|
Carlo Baldassi, Christian Borgs, Jennifer T. Chayes, Alessandro Ingrosso, Carlo Lucibello, Luca |
|
Saglietti, and Riccardo Zecchina. Unreasonable effectiveness of learning neural networks: From |
|
accessible states and robust ensembles to basic algorithmic schemes. Proceedings of the National |
|
_Academy of Sciences, 113(48):E7655–E7662, 2016a. ISSN 0027-8424. doi: 10.1073/pnas._ |
|
[1608103113. URL https://www.pnas.org/content/113/48/E7655.](https://www.pnas.org/content/113/48/E7655) |
|
|
|
Carlo Baldassi, Federica Gerace, Carlo Lucibello, Luca Saglietti, and Riccardo Zecchina. Learning |
|
may need only a few bits of synaptic precision. Phys. Rev. E, 93:052313, May 2016b. doi: 10. |
|
[1103/PhysRevE.93.052313. URL https://link.aps.org/doi/10.1103/PhysRevE.](https://link.aps.org/doi/10.1103/PhysRevE.93.052313) |
|
[93.052313.](https://link.aps.org/doi/10.1103/PhysRevE.93.052313) |
|
|
|
Carlo Baldassi, Fabrizio Pittorino, and Riccardo Zecchina. Shaping the learning landscape in neural |
|
networks around wide flat minima. Proceedings of the National Academy of Sciences, 117(1): |
|
[161–170, 2020. ISSN 0027-8424. doi: 10.1073/pnas.1908636117. URL https://www.pnas.](https://www.pnas.org/content/117/1/161) |
|
[org/content/117/1/161.](https://www.pnas.org/content/117/1/161) |
|
|
|
Jean Barbier, Florent Krzakala, Nicolas Macris, Léo Miolane, and Lenka Zdeborová. Optimal errors |
|
and phase transitions in high-dimensional generalized linear models. Proceedings of the National |
|
_Academy of Sciences, 116(12):5451–5460, 2019. ISSN 0027-8424. doi: 10.1073/pnas.1802705116._ |
|
[URL https://www.pnas.org/content/116/12/5451.](https://www.pnas.org/content/116/12/5451) |
|
|
|
Hans Bethe. Statistical theory of superlattices. Proc. R. Soc. A, 150:552, 1935. |
|
|
|
Alfredo Braunstein and Riccardo Zecchina. Learning by message passing in networks of discrete |
|
synapses. Phys. Rev. Lett., 96:030201, Jan 2006. doi: 10.1103/PhysRevLett.96.030201. URL |
|
[https://link.aps.org/doi/10.1103/PhysRevLett.96.030201.](https://link.aps.org/doi/10.1103/PhysRevLett.96.030201) |
|
|
|
Pratik Chaudhari, Anna Choromanska, Stefano Soatto, Yann LeCun, Carlo Baldassi, Christian Borgs, |
|
Jennifer T. Chayes, Levent Sagun, and Riccardo Zecchina. Entropy-sgd: Biasing gradient descent |
|
into wide valleys. In 5th International Conference on Learning Representations, ICLR 2017, |
|
_Toulon, France, April 24-26, 2017, Conference Track Proceedings. OpenReview.net, 2017. URL_ |
|
[https://openreview.net/forum?id=B1YfAfcgl.](https://openreview.net/forum?id=B1YfAfcgl) |
|
|
|
James Diffenderfer and Bhavya Kailkhura. Multi-prize lottery ticket hypothesis: Finding accurate |
|
binary neural networks by pruning a randomly weighted network. In International Confer_[ence on Learning Representations, 2021. URL https://openreview.net/forum?id=](https://openreview.net/forum?id=U_mat0b9iv)_ |
|
[U_mat0b9iv.](https://openreview.net/forum?id=U_mat0b9iv) |
|
|
|
|
|
|
|
|
Yu Feng and Yuhai Tu. The inverse variance–flatness relation in stochastic gradient descent is critical |
|
for finding flat minima. Proceedings of the National Academy of Sciences, 118(9), 2021. |
|
|
|
Alyson K Fletcher, Sundeep Rangan, and Philip Schniter. Inference in deep networks in high |
|
dimensions. In 2018 IEEE International Symposium on Information Theory (ISIT), pp. 1884–1888. |
|
IEEE, 2018. |
|
|
|
Jonathan Frankle, Gintare Karolina Dziugaite, Daniel Roy, and Michael Carbin. Pruning neural |
|
networks at initialization: Why are we missing the mark? In International Conference on Learning |
|
_[Representations, 2021. URL https://openreview.net/forum?id=Ig-VyQc-MLK.](https://openreview.net/forum?id=Ig-VyQc-MLK)_ |
|
|
|
Stefano Fusi, Patrick J Drew, and Larry F Abbott. Cascade models of synaptically stored memories. |
|
_Neuron, 45(4):599–611, 2005._ |
|
|
|
Marylou Gabrié. Mean-field inference methods for neural networks. Journal of Physics A: Mathe_matical and Theoretical, 53(22):223002, 2020._ |
|
|
|
Robert Gallager. Low-density parity-check codes. IRE Transactions on information theory, 8(1): |
|
21–28, 1962. |
|
|
|
Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry P Vetrov, and Andrew G Wilson. Loss |
|
surfaces, mode connectivity, and fast ensembling of dnns. In S. Bengio, H. Wallach, H. Larochelle, |
|
K. Grauman, N. Cesa-Bianchi, and R. Garnett (eds.), Advances in Neural Information Processing |
|
_[Systems, volume 31. Curran Associates, Inc., 2018. URL https://proceedings.neurips.](https://proceedings.neurips.cc/paper/2018/file/be3087e74e9100d4bc4c6268cdbe8456-Paper.pdf)_ |
|
[cc/paper/2018/file/be3087e74e9100d4bc4c6268cdbe8456-Paper.pdf.](https://proceedings.neurips.cc/paper/2018/file/be3087e74e9100d4bc4c6268cdbe8456-Paper.pdf) |
|
|
|
Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward |
|
neural networks. In Yee Whye Teh and Mike Titterington (eds.), Proceedings of the Thirteenth |
|
_International Conference on Artificial Intelligence and Statistics, volume 9 of Proceedings of_ |
|
_Machine Learning Research, pp. 249–256, Chia Laguna Resort, Sardinia, Italy, 13–15 May 2010._ |
|
[PMLR. URL https://proceedings.mlr.press/v9/glorot10a.html.](https://proceedings.mlr.press/v9/glorot10a.html) |
|
|
|
Sebastian Goldt, Marc Mézard, Florent Krzakala, and Lenka Zdeborová. Modeling the influence of |
|
data structure on learning in neural networks: The hidden manifold model. Physical Review X, 10 |
|
(4):041044, 2020. |
|
|
|
Ian J Goodfellow, Mehdi Mirza, Da Xiao, Aaron Courville, and Yoshua Bengio. An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211, |
|
2013. |
|
|
|
Song Han, Huizi Mao, and William J. Dally. Deep compression: Compressing deep neural network |
|
with pruning, trained quantization and huffman coding. In Yoshua Bengio and Yann LeCun (eds.), |
|
_4th International Conference on Learning Representations, ICLR 2016, San Juan, Puerto Rico,_ |
|
_[May 2-4, 2016, Conference Track Proceedings, 2016. URL http://arxiv.org/abs/1510.](http://arxiv.org/abs/1510.00149)_ |
|
[00149.](http://arxiv.org/abs/1510.00149) |
|
|
|
José Miguel Hernández-Lobato and Ryan P. Adams. Probabilistic backpropagation for scalable |
|
learning of bayesian neural networks. In Proceedings of the 32nd International Conference on |
|
_International Conference on Machine Learning - Volume 37, ICML’15, pp. 1861–1869. JMLR.org,_ |
|
2015. |
|
|
|
Itay Hubara, Matthieu Courbariaux, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio. Binarized neural networks. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 29. Curran Asso[ciates, Inc., 2016. URL https://proceedings.neurips.cc/paper/2016/file/](https://proceedings.neurips.cc/paper/2016/file/d8330f857a17c53d217014ee776bfd50-Paper.pdf) |
|
[d8330f857a17c53d217014ee776bfd50-Paper.pdf.](https://proceedings.neurips.cc/paper/2016/file/d8330f857a17c53d217014ee776bfd50-Paper.pdf) |
|
|
|
Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and Samy Bengio. Fantastic generalization measures and where to find them. In International Conference on Learning |
|
_[Representations, 2020. URL https://openreview.net/forum?id=SJgIPJBFvH.](https://openreview.net/forum?id=SJgIPJBFvH)_ |
|
|
|
Yoshiyuki Kabashima, Florent Krzakala, Marc Mézard, Ayaka Sakata, and Lenka Zdeborová. Phase |
|
transitions and sample complexity in bayes-optimal matrix factorization. IEEE Transactions on |
|
_Information Theory, 62(7):4228–4265, 2016. doi: 10.1109/TIT.2016.2556702._ |
|
|
|
|
|
|
|
|
James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A |
|
Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, et al. Overcoming |
|
catastrophic forgetting in neural networks. Proceedings of the national academy of sciences, 114 |
|
(13):3521–3526, 2017. |
|
|
|
Jonathan Kuck, Shuvam Chakraborty, Hao Tang, Rachel Luo, Jiaming Song, Ashish Sabharwal, and |
|
Stefano Ermon. Belief propagation neural networks. In H. Larochelle, M. Ranzato, R. Hadsell, |
|
M. F. Balcan, and H. Lin (eds.), Advances in Neural Information Processing Systems, volume 33, |
|
[pp. 667–678. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/](https://proceedings.neurips.cc/paper/2020/file/07217414eb3fbe24d4e5b6cafb91ca18-Paper.pdf) |
|
[paper/2020/file/07217414eb3fbe24d4e5b6cafb91ca18-Paper.pdf.](https://proceedings.neurips.cc/paper/2020/file/07217414eb3fbe24d4e5b6cafb91ca18-Paper.pdf) |
|
|
|
Axel Laborieux, Maxence Ernoult, Tifenn Hirtzlin, and Damien Querlioz. Synaptic metaplasticity in binarized neural networks. Nature Communications, 12(1):2549, May 2021. ISSN |
|
2041-1723. doi: 10.1038/s41467-021-22768-y. [URL https://doi.org/10.1038/](https://doi.org/10.1038/s41467-021-22768-y) |
|
[s41467-021-22768-y.](https://doi.org/10.1038/s41467-021-22768-y) |
|
|
|
Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer, and Tom Goldstein. Visualizing the loss landscape of neural nets. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and |
|
R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 31. Curran As[sociates, Inc., 2018. URL https://proceedings.neurips.cc/paper/2018/file/](https://proceedings.neurips.cc/paper/2018/file/a41b3bb3e6b050b6c9067c67f663b915-Paper.pdf) |
|
[a41b3bb3e6b050b6c9067c67f663b915-Paper.pdf.](https://proceedings.neurips.cc/paper/2018/file/a41b3bb3e6b050b6c9067c67f663b915-Paper.pdf) |
|
|
|
Antoine Maillard, Florent Krzakala, Marc Mézard, and Lenka Zdeborová. Perturbative construction |
|
of mean-field equations in extensive-rank matrix factorization and denoising. arXiv preprint |
|
_arXiv:2110.08775, 2021._ |
|
|
|
Andre Manoel, Florent Krzakala, Marc Mézard, and Lenka Zdeborová. Multi-layer generalized linear |
|
estimation. In 2017 IEEE International Symposium on Information Theory (ISIT), pp. 2098–2102, |
|
2017. doi: 10.1109/ISIT.2017.8006899. |
|
|
|
Michael McCloskey and Neal J Cohen. Catastrophic interference in connectionist networks: The |
|
sequential learning problem. In Psychology of learning and motivation, volume 24, pp. 109–165. |
|
Elsevier, 1989. |
|
|
|
Marc Mézard. Mean-field message-passing equations in the hopfield model and its generalizations. |
|
_Physical Review E, 95(2):022117, 2017._ |
|
|
|
Marc Mézard, Giorgio Parisi, and Miguel Angel Virasoro. Spin glass theory and beyond: An |
|
_Introduction to the Replica Method and Its Applications, volume 9. World Scientific Publishing_ |
|
Company, 1987. |
|
|
|
Thomas P. Minka. Expectation propagation for approximate bayesian inference. In Proceedings of |
|
_the Seventeenth Conference on Uncertainty in Artificial Intelligence, UAI’01, pp. 362–369, San_ |
|
Francisco, CA, USA, 2001. Morgan Kaufmann Publishers Inc. ISBN 1558608001. |
|
|
|
Marc Mézard and Andrea Montanari. Information, Physics, and Computation. Oxford University |
|
Press, Inc., USA, 2009. ISBN 019857083X. |
|
|
|
Jason T Parker, Philip Schniter, and Volkan Cevher. Bilinear generalized approximate message |
|
passing—part i: Derivation. IEEE Transactions on Signal Processing, 62(22):5839–5853, 2014. |
|
|
|
Judea Pearl. Reverend Bayes on inference engines: A distributed hierarchical approach. Cognitive |
|
Systems Laboratory, School of Engineering and Applied Science ..., 1982. |
|
|
|
R. Peierls. On ising’s model of ferromagnetism. Mathematical Proceedings of the Cambridge |
|
_Philosophical Society, 32(3):477–481, 1936. doi: 10.1017/S0305004100019174._ |
|
|
|
Fabrizio Pittorino, Carlo Lucibello, Christoph Feinauer, Gabriele Perugini, Carlo Baldassi, Elizaveta |
|
Demyanenko, and Riccardo Zecchina. Entropic gradient descent algorithms and wide flat minima. |
|
[In International Conference on Learning Representations, 2021. URL https://openreview.](https://openreview.net/forum?id=xjXg0bnoDmS) |
|
[net/forum?id=xjXg0bnoDmS.](https://openreview.net/forum?id=xjXg0bnoDmS) |
|
|
|
Sundeep Rangan, Philip Schniter, and Alyson K Fletcher. Vector approximate message passing. IEEE |
|
_Transactions on Information Theory, 65(10):6664–6684, 2019._ |
|
|
|
|
|
|
|
|
Rajesh P. N. Rao. Neural models of Bayesian belief propagation., pp. 239–267. Bayesian brain: Probabilistic approaches to neural coding. MIT Press, Cambridge, MA, US, 2007. ISBN 026204238X |
|
(Hardcover); 978-0-262-04238-3 (Hardcover). |
|
|
|
Anthony Robins. Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science, 7(2): |
|
123–146, 1995. |
|
|
|
Victor Garcia Satorras and Max Welling. Neural enhanced belief propagation on factor graphs. In |
|
_International Conference on Artificial Intelligence and Statistics, pp. 685–693. PMLR, 2021._ |
|
|
|
Daniel Soudry, Itay Hubara, and Ron Meir. Expectation backpropagation: Parameterfree training of multilayer neural networks with continuous or discrete weights. In |
|
Z. Ghahramani, M. Welling, C. Cortes, N. Lawrence, and K. Q. Weinberger (eds.), |
|
_Advances in Neural Information Processing Systems,_ volume 27. Curran Associates, |
|
Inc., 2014a. URL [https://proceedings.neurips.cc/paper/2014/file/](https://proceedings.neurips.cc/paper/2014/file/076a0c97d09cf1a0ec3e19c7f2529f2b-Paper.pdf) |
|
[076a0c97d09cf1a0ec3e19c7f2529f2b-Paper.pdf.](https://proceedings.neurips.cc/paper/2014/file/076a0c97d09cf1a0ec3e19c7f2529f2b-Paper.pdf) |
|
|
|
Daniel Soudry, Itay Hubara, and Ron Meir. Expectation backpropagation: Parameter-free training of |
|
multilayer neural networks with continuous or discrete weights. In NIPS, volume 1, pp. 2, 2014b. |
|
|
|
George Stamatescu, Federica Gerace, Carlo Lucibello, Ian Fuss, and Langford B. White. Critical |
|
[initialisation in continuous approximations of binary neural networks. 2020. URL https:](https://openreview.net/forum?id=rylmoxrFDH) |
|
[//openreview.net/forum?id=rylmoxrFDH.](https://openreview.net/forum?id=rylmoxrFDH) |
|
|
|
Yi-Lin Sung, Varun Nair, and Colin Raffel. Training neural networks with fixed sparse masks, 2021. |
|
|
|
Frederick Tung and Greg Mori. Clip-q: Deep network compression learning by in-parallel pruningquantization. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. |
|
7873–7882, 2018. doi: 10.1109/CVPR.2018.00821. |
|
|
|
Anqi Wu, Sebastian Nowozin, Edward Meeds, Richard E Turner, Jose Miguel Hernandez-Lobato, |
|
and Alexander L Gaunt. Deterministic variational inference for robust bayesian neural networks. |
|
_arXiv preprint arXiv:1810.03958, 2018._ |
|
|
|
Jonathan S. Yedidia, William T. Freeman, and Yair Weiss. Understanding Belief Propagation and Its |
|
_Generalizations, pp. 239–269. Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, 2003._ |
|
ISBN 1558608117. |
|
|
|
Lenka Zdeborová and Florent Krzakala. Statistical physics of inference: Thresholds and algorithms. |
|
_Advances in Physics, 65(5):453–552, 2016._ |
|
|
|
Friedemann Zenke, Ben Poole, and Surya Ganguli. Continual learning through synaptic intelligence. |
|
In International Conference on Machine Learning, pp. 3987–3995. PMLR, 2017. |
|
|
|
Qiuyun Zou, Haochuan Zhang, and Hongwen Yang. Multi-layer bilinear generalized approximate |
|
message passing. IEEE Transactions on Signal Processing, 69:4529–4543, 2021. doi: 10.1109/ |
|
TSP.2021.3100305. |
|
|
|
|
|
|
|
|
# Appendices |
|
|
|
CONTENTS |
|
|
|
**A BP-based message passing algorithms**

A.1 Preliminary considerations
A.2 Derivation of the BP equations
A.3 BP equations
A.4 BPI equations
A.5 MF equations
A.6 Derivation of the AMP equations
A.7 AMP equations
A.8 Activation Functions
A.9 The ArgMax layer

**B Experimental details**

B.1 Hyper-parameters of the BP-based scheme
B.2 Damping scheme for the message passing
B.3 Architectures
B.4 Varying the dataset
B.5 Local energy
B.6 SGD implementation (BinaryNet)
B.7 EBP implementation
B.8 Unit polarization and overlaps
B.9 Computational performance: varying batch-size
|
|
|
A BP-BASED MESSAGE PASSING ALGORITHMS |
|
|
|
A.1 PRELIMINARY CONSIDERATIONS |
|
|
|
Given a mini-batch $\mathcal{B} = \{(x_n, y_n)\}_n$, the factor graph defined by Eqs. (1, 2, 18) is explicitly written as:

$$P(\mathcal{W}, x^{1:L} \mid \mathcal{B}, \theta) \propto \prod_{\ell=0}^{L} \prod_{k,n} P^{\ell+1}\Big(x^{\ell+1}_{kn} \,\Big|\, \sum_i W^\ell_{ki} x^\ell_{in}\Big) \prod_{k,i,\ell} q_\theta(W^\ell_{ki}), \qquad (18)$$

where $x^0_n = x_n$ and $x^{L+1}_n = y_n$. The derivation of the BP equations for this model is straightforward, albeit lengthy and involved. It is obtained following the steps presented in multiple papers, books, and reviews, see for instance (Mézard & Montanari, 2009; Zdeborová & Krzakala, 2016; Mézard, 2017), although it has not been attempted before in deep neural networks. It should be noted that a (common) approximation that we take here with respect to the standard BP scheme is that messages are assumed to be Gaussian distributed and therefore parameterized by their mean and variance. This goes under the name of relaxed belief propagation (rBP), just referred to as BP throughout the paper.
|
|
|
We derive the BP equations in A.2 and present them all together in A.3. From BP, we derive three other message passing algorithms useful for the deep network training setting, all of which are well known in the literature: BP-Inspired (BPI) message passing (A.4), mean-field (MF) (A.5), and approximate message passing (AMP) (A.7). The AMP derivation is the most involved and is given in A.6. In all these cases, message updates can be divided into a forward pass and a backward pass, as also done in Fletcher et al. (2018) in a multi-layer inference setting. The BP algorithm is compactly reported in Algorithm 1.

In our notation, $\ell$ denotes the layer index, $\tau$ the BP iteration index, $k$ an output neuron index, $i$ an input neuron index, and $n$ a sample index.
|
|
|
We report below, for convenience, some of the considerations also present in the main text. |
|
|
|
**Meaning of messages.** All the messages involved in the message passing equations can be understood in terms of cavity marginals or full marginals (as mentioned in the introduction, BP is also known as the Cavity Method, see Mézard & Montanari (2009)). Of particular relevance are the quantities $m^\ell_{ki}$ and $\sigma^\ell_{ki}$, denoting the mean and variance of the weight $W^\ell_{ki}$. The quantities $\hat{x}^\ell_{in}$ and $\Delta^\ell_{in}$ instead denote the mean and variance of the $i$-th neuron's activation in layer $\ell$ in correspondence of an input $x_n$.
|
|
|
**Scalar free energies.** All message passing schemes can be expressed using the following scalar functions, corresponding to single-neuron and single-weight effective free energies respectively:

$$\phi^\ell(B, A, \omega, V) = \log \int \mathrm{d}x\, \mathrm{d}z\; e^{-\frac{1}{2}Ax^2 + Bx}\, P^\ell(x \mid z)\, e^{-\frac{(\omega - z)^2}{2V}}, \qquad (19)$$

$$\psi(H, G, \theta) = \log \int \mathrm{d}w\; e^{-\frac{1}{2}Gw^2 + Hw}\, q_\theta(w). \qquad (20)$$

These free energies will naturally arise in the derivation of the BP equations in Appendix A.2. For the last layer, the neuron function has to be slightly modified:

$$\phi^{L+1}(y, \omega, V) = \log \int \mathrm{d}z\; P^{L+1}(y \mid z)\, e^{-\frac{(\omega - z)^2}{2V}}. \qquad (21)$$

Notice that for common deterministic activations such as ReLU and sign, the function $\phi$ has analytic and smooth expressions that we give in Appendix A.8. The same goes for $\psi$ when $q_\theta(w)$ is Gaussian (continuous weights) or a mixture of atoms (discrete weights). At the last layer we impose $P^{L+1}(y \mid z) = \mathbb{I}(y = \mathrm{sign}(z))$ in binary classification tasks. For multi-class classification instead, we have to adapt the formalism to vectorial pre-activations $z$ and assume $P^{L+1}(y \mid z) = \mathbb{I}(y = \arg\max(z))$ (see Appendix A.9). While in our experiments we use hard constraints for the final output, therefore solving a constraint satisfaction problem, it would be interesting to also consider generic loss functions. That would require minimal changes to our formalism, but this is beyond the scope of our work.
|
|
|
**Binary weights.** In our experiments we use ±1 weights in each layer. Therefore each marginal can be parameterized by a single number, and our prior/posterior takes the form

$$q_\theta(W^\ell_{ki}) \propto e^{\theta^\ell_{ki} W^\ell_{ki}}. \qquad (22)$$

The effective free energy function of Eq. 20 becomes

$$\psi(H, G, \theta^\ell_{ki}) = \log 2\cosh(H + \theta^\ell_{ki}), \qquad (23)$$

and the messages $G$ can be dropped from the message passing.
|
|
|
**Start and end of message passing.** At the beginning of a new PasP iteration $t$, we reset the messages to zero and run message passing for $\tau_{\max}$ iterations. We then compute the new prior $q_{\theta^{t+1}}(\mathcal{W})$ from the posterior given by the message passing iterations.
|
|
|
A.2 DERIVATION OF THE BP EQUATIONS |
|
|
|
In order to derive the BP equations, we start with the following portion of the factor graph reported in Eq. 18 in the main text, describing the contribution of a single data example in the inner loop of the PasP updates:

$$\prod_{\ell=0}^{L} \prod_{k} P^{\ell+1}\Big(x^{\ell+1}_{kn} \,\Big|\, \sum_i W^\ell_{ki} x^\ell_{in}\Big), \qquad \text{where } x^0_n = x_n,\ x^{L+1}_n = y_n, \qquad (24)$$

and where we recall that the quantity $x^\ell_{kn}$ corresponds to the activation of neuron $k$ in layer $\ell$ in correspondence of the input example $n$.

Let us start by analyzing the single factor:

$$P^{\ell+1}\Big(x^{\ell+1}_{kn} \,\Big|\, \sum_i W^\ell_{ki} x^\ell_{in}\Big). \qquad (25)$$

We refer to messages that travel from input to output in the factor graph as upgoing or upwards messages, and to the ones that travel from output to input as downgoing or backwards messages.

**Factor-to-variable-W messages.** The factor-to-variable-$W$ messages read:

$$\hat{\nu}^{\ell+1}_{kn \to ki}(W^\ell_{ki}) \propto \int \prod_{i' \neq i} \mathrm{d}\nu^\ell_{ki' \to n}(W^\ell_{ki'}) \prod_{i'} \mathrm{d}\nu^\ell_{i'n \to k}(x^\ell_{i'n})\, \mathrm{d}\nu_\downarrow(x^{\ell+1}_{kn})\; P^{\ell+1}\Big(x^{\ell+1}_{kn} \,\Big|\, \sum_{i'} W^\ell_{ki'} x^\ell_{i'n}\Big), \qquad (26)$$

where $\nu_\downarrow$ denotes the messages travelling downwards (from output to input) in the factor graph.

We denote the means and variances of the incoming messages respectively with $m^\ell_{ki \to n}$, $\hat{x}^\ell_{in \to k}$ and $\sigma^\ell_{ki \to n}$, $\Delta^\ell_{in \to k}$:

$$m^\ell_{ki \to n} = \int \mathrm{d}\nu^\ell_{ki \to n}(W^\ell_{ki})\, W^\ell_{ki}, \qquad (27)$$

$$\sigma^\ell_{ki \to n} = \int \mathrm{d}\nu^\ell_{ki \to n}(W^\ell_{ki})\, \big(W^\ell_{ki} - m^\ell_{ki \to n}\big)^2. \qquad (28)$$
|
|