
Learning Message-Passing Inference Machines for Structured Prediction

Stéphane Ross    Daniel Munoz    Martial Hebert    J. Andrew Bagnell

The Robotics Institute, Carnegie Mellon University

stephaneross@cmu.edu, {dmunoz, hebert, dbagnell}@ri.cmu.edu

Abstract

Nearly every structured prediction problem in computer vision requires approximate inference due to large and complex dependencies among output labels. While graphical

models provide a clean separation between modeling and

inference, learning these models with approximate infer-

ence is not well understood. Furthermore, even if a good

model is learned, predictions are often inaccurate due to

approximations. In this work, instead of performing infer-

ence over a graphical model, we instead consider the in-

ference procedure as a composition of predictors. Specif-

ically, we focus on message-passing algorithms, such as

Belief Propagation, and show how they can be viewed as

procedures that sequentially predict label distributions at

each node over a graph. Given labeled graphs, we can then

train the sequence of predictors to output the correct label-

ings. The result no longer corresponds to a graphical model

but simply defines an inference procedure, with strong the-

oretical properties, that can be used to classify new graphs.

We demonstrate the scalability and efficacy of our approach

on 3D point cloud classification and 3D surface estimation

from single images.

1. Introduction

Probabilistic graphical models, such as Conditional Ran-

dom Fields (CRFs) [11], have proven to be a remarkably

successful tool for structured prediction that, in principle,

provide a clean separation between modeling and infer-

ence. However, exact inference for problems in computer

vision (e.g., Fig. 1) is often intractable due to a large num-

ber of dependent output variables (e.g., one for each (su-

per)pixel). In order to cope with this problem, inference

inevitably relies on approximate methods: Monte-Carlo,

loopy belief propagation, graph-cuts, and variational meth-

ods [16, 2, 23]. Unfortunately, learning these models with

approximate inference is not well understood [9, 6]. Addi-

tionally, it has been observed that it is important to tie the

graphical model to the specific approximate inference pro-

cedure used at test time to obtain better predictions [10, 22].

Figure 1: Applications of structured prediction in computer

vision. Left: 3D surface layout estimation. Right: 3D point

cloud classification.

When the learned graphical model is tied to the inference

procedure, the graphical model is not necessarily a good

probabilistic model of the data but simply a parametrization

of the inference procedure that yields the best predictions

within the “class” of inference procedures considered. This

raises an important question: if the ultimate goal is to obtain the best predictions, then why is the inference procedure optimized indirectly by learning a graphical model? Perhaps

it is possible to optimize the inference procedure more di-

rectly, without building an explicit probabilistic model over

the data.

Some recent approaches [4, 21] eschew the probabilistic

graphical model entirely with notable successes. However,

we would ideally like to have the best of both worlds: the

proven success of error-correcting iterative decoding meth-

ods along with a tight relationship between learning and in-

ference. To enable this combination, we propose an alter-

nate view of the approximate inference process as a long

sequence of computational modules to optimize [1] such

that the sequence results in correct predictions. We focus on

message-passing inference procedures, such as Belief Prop-

agation, which compute marginal distributions over out-

put variables by iteratively visiting all nodes in the graph

and passing messages to neighbors which consist of “cav-

ity marginals”, i.e., a series of marginals with the effect of

each neighbor removed. Message-passing inference can be

viewed as a function applied iteratively to each variable that

takes as input local observations/features and local compu-


tations on the graph (messages) and provides as output the

intermediate messages/marginals. Hence, such a procedure

can be trained directly by training a predictor which predicts a current variable's marginal¹ given local features and

a subset of neighbors’ cavity marginals. By training such a

predictor, there is no need to have a probabilistic graphical

model of the data, and there need not be any probabilistic

model that corresponds to the computations performed by

the predictor. The inference procedure is instead thought of

as a black box function that is trained to yield correct pre-

dictions. This is analogous to many discriminative learning

methods; it may be easier to simply discriminate between

classes than build a generative probabilistic model of them.

There are a number of advantages to doing message-

passing inference as a sequence of predictions. Considering

different classes of predictors allows one to obtain entirely different classes of inference procedures that perform different approximations. The level of approximation and computational complexity of the inference can be controlled in

part by considering more or less complex classes of predic-

tors. This allows one to naturally trade-off accuracy versus

speed of inference in real-time settings. Furthermore, in

contrast with most approaches to learning inference, we are

able to provide rigorous reduction-style guarantees [18] on

the performance of the resulting inference procedure.

Training such a predictor, however, is non-trivial as

the interdependencies in the sequence of predictions make

global optimization difficult. Building from success in deep

learning, a first key technique we use is to leverage in-

formation local to modules to aid learning [1]. Because

each module’s prediction in the sequence corresponds to

the computation of a particular variable’s marginal, we ex-

ploit this information and try to make these intermediate

inference steps match the ideal output in our training data

(i.e., a marginal with probability 1 to the correct class).

To provide good guarantees and performance in practice in

this non-i.i.d. setting (as predictions are interdependent),

we also leverage key iterative training methods developed

in prior work for imitation learning and structured predic-

tion [17, 18, 4]. These techniques allow us to iteratively

train probabilistic predictors that predict the ideal variable

marginals under the distribution of inputs the learned pre-

dictors induce during inference. Optionally, we may refine

performance using an optimization procedure such as back-

propagation through the sequence of predictions in order to

improve the overall objective (i.e., minimize loss of the final

marginals).

In the next section, we first review message-passing in-

ference for graphical models. We then present our approach

for training message-passing inference procedures in Sec.

3. In Sec. 4, we demonstrate the efficacy of our proposed

¹ For lack of a better term, we will use marginal throughout to mean a distribution over one variable's labels.

approach by demonstrating state-of-the-art results on a large

3D point cloud classification task as well as in estimating

geometry from a single image.

2. Graphical Models and Message-Passing

Graphical models provide a natural way of encoding

spatial dependencies and interactions between neighboring

sites (pixels, superpixels, segments, etc.) in many computer

vision applications such as scene labeling (Fig. 1). A graph-

ical model represents a joint (conditional) distribution over

labelings of each site (node), via a factor graph (a bipartite

graph between output variables and factors) defined by a set

of variable nodes (sites) V, a set of factor nodes (potentials) F, and a set of edges E between them:

P(Y|X) ∝ ∏_{f∈F} φ_f(x_f, y_f),

where X, Y are the vectors of all observed features and output labels respectively, x_f the features related to factor f, and y_f the vector of labels for the nodes connected to factor f. A typical graphical model will have node potentials (factors connected to a single variable node) and pairwise potentials (factors connected between 2 nodes). It is also possible to consider higher-order potentials by having a factor connect many nodes (e.g., cluster/segment potentials as in [14]). Training a graphical model is achieved by optimizing the potentials φ_f on an objective function (e.g., margin, pseudo-likelihood, etc.) defined over training data. To classify a new scene, an (approximate) inference procedure estimates the most likely joint label assignment or the marginals over labels at each node.

Loopy Belief Propagation (BP) [16] is perhaps the canonical message-passing algorithm for performing (approximate) inference in graphical models. Let N_v be the set of factors connected to variable v, N_v^{-f} the set of factors connected to v except factor f, N_f the set of variables connected to factor f, and N_f^{-v} the set of variables connected to f except variable v. At a variable v ∈ V, BP sends a message m_vf to each factor f in N_v:

m_vf(y_v) ∝ ∏_{f'∈N_v^{-f}} m_{f'v}(y_v),

where m_vf(y_v) denotes the value of the message for assignment y_v to variable v. At a factor f ∈ F, BP sends a message m_fv to each variable v in N_f:

m_fv(y_v) ∝ Σ_{y'_f | y'_v = y_v} φ_f(y'_f, x_f) ∏_{v'∈N_f^{-v}} m_{v'f}(y'_{v'}),

where y'_f is an assignment to all variables v' connected to f, y'_{v'} is the particular assignment to v' (in y'_f), and φ_f is the potential function associated with factor f, which depends on y'_f and potentially other observed features x_f (e.g., in the CRF). Finally, the marginal of variable v is obtained as:

P(v = y_v) ∝ ∏_{f∈N_v} m_fv(y_v).
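These update rules can be checked concretely. Below is a toy numerical sketch of the three equations above on a two-variable chain with one pairwise factor; the potentials are made-up numbers, and on a tree like this BP recovers the exact marginals:

```python
import numpy as np

# Toy sketch of the BP updates above on a 2-variable chain A -- f -- B
# (one pairwise factor plus a unary factor per variable). All potential
# values are invented for illustration.
phi_A = np.array([0.7, 0.3])          # unary potential on A
phi_B = np.array([0.4, 0.6])          # unary potential on B
phi_AB = np.array([[0.9, 0.1],        # pairwise potential phi_f(y_A, y_B)
                   [0.2, 0.8]])

# Variable -> factor: product of incoming factor messages except f.
# A's only other factor is its unary, so m_Af is just phi_A (normalized).
m_Af = phi_A / phi_A.sum()
m_Bf = phi_B / phi_B.sum()

# Factor -> variable: sum over the other variable's assignments,
# m_fA(y_A) = sum_{y_B} phi_AB[y_A, y_B] * m_Bf(y_B)
m_fA = phi_AB @ m_Bf
m_fB = phi_AB.T @ m_Af

# Marginal: product of all incoming factor messages, normalized.
p_A = phi_A * m_fA
p_A /= p_A.sum()
p_B = phi_B * m_fB
p_B /= p_B.sum()

# On a tree, BP is exact: check against brute-force enumeration.
joint = phi_A[:, None] * phi_B[None, :] * phi_AB
joint /= joint.sum()
assert np.allclose(p_A, joint.sum(axis=1))
assert np.allclose(p_B, joint.sum(axis=0))
```

On loopy graphs the same updates are only approximate, which is what motivates viewing them as trainable predictors in the next section.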

The messages in BP can be sent synchronously (i.e., all messages over the graph are computed before they are sent to their neighbors) or asynchronously (i.e., by sending the message to the neighbor immediately). When proceeding asynchronously, BP usually starts at a random variable node, with messages initialized uniformly, and then proceeds iteratively through the factor graph by visiting variables and factors in a breadth-first-search manner (forward and then in backward/reverse order) several times or until convergence. The final marginals at each variable are computed using the last equation. Asynchronous message passing often allows faster convergence, and methods such as Residual BP [5] have been developed to achieve still faster convergence by prioritizing the messages to compute.

2.1. Understanding Message Passing as

Sequential Probabilistic Classification

By definition of P(v = yv), the message mvf can be

interpreted as the marginal of variable v when the factor f

(and its influence) is removed from the graph. This is often

referred as the cavity method in statistical mechanics [3]

and m_vf are known as cavity marginals. By expanding the definition of m_vf, we can see that it may depend only on the messages m_{v'f'} sent by all variables v' connected to v by a factor f' ≠ f:

m_vf(y_v) ∝ ∏_{f'∈N_v^{-f}} Σ_{y'_{f'} | y'_v = y_v} φ_{f'}(y'_{f'}, x_{f'}) ∏_{v'∈N_{f'}^{-v}} m_{v'f'}(y'_{v'}).   (1)

Hence the messages m_vf leaving a variable v toward a factor f in BP can be thought of as the classification of the current variable v (marginal distribution over classes) using the cavity marginals m_{v'f'} sent by variables v' connected to v through a factor f' ≠ f. In this view, BP is iteratively classifying the variables in the graph by performing a sequence of classifications (marginals) for each message leaving a variable. The final marginals P(v = y_v) are then obtained by classifying v using all messages from all variables v' connected to v through some factor f ∈ N_v.
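This classification view can be sketched directly: a single predictor h maps a node's local features together with its neighbors' cavity marginals to a new marginal. The softmax predictor, the mean-pooling of neighbor messages, and the weights below are all hypothetical stand-ins for whatever trained classifier one would actually use:

```python
import numpy as np

# Sketch of "message passing as classification": a generic predictor maps
# a node's local features plus its neighbors' cavity marginals to a new
# marginal over labels. Fixed random weights stand in for learned ones.
rng = np.random.default_rng(0)
K, D = 3, 4                        # number of labels, feature dimension
W = rng.normal(size=(K, D + K))    # hypothetical learned weights

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def predict_marginal(x, neighbor_cavity_marginals):
    """One BP-style step: classify a node from features + neighbor messages."""
    msg = np.mean(neighbor_cavity_marginals, axis=0)  # pool neighbor info
    return softmax(W @ np.concatenate([x, msg]))

x = rng.normal(size=D)
uniform = np.full(K, 1.0 / K)
m1 = predict_marginal(x, [uniform, uniform])  # first pass: no information
m2 = predict_marginal(x, [m1, m1])            # later pass: informed messages
assert np.isclose(m1.sum(), 1.0) and np.isclose(m2.sum(), 1.0)
```

Any predictor of this shape (logistic regression, boosted trees, random forests) can be slotted in, which is exactly the freedom the text argues for below.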

An example of how BP unrolls to a sequence of interde-

pendent local classifications is shown in Fig. 2 for a sim-

ple graph. In this view, the job of the predictor is not only

to emulate the computation going on during BP at variable

nodes, but also emulate the computations going on at all

the factors connected to the variable to which it is not sending the message, as shown in Fig. 3. During inference, BP effectively employs a probabilistic predictor that has the form in Equation 1, where the inputs are the messages m_{v'f'} and local observed features x_{f'}.

Figure 2: Depiction of how BP unrolls into a sequence of predictions for 3 passes on the graph on the left with 3 variables (A, B, C) and 3 factors (1, 2, 3), starting at A. The sequence of predictions is shown on the right, where, e.g., A1 denotes the prediction (message) of A sent to factor 1, while the outputs (final marginals) are in gray and denoted by the corresponding variable letter. Input arrows indicate the previous outputs that are used in the computation of each message.

Figure 3: Depiction of the computations that the predictor represents in BP for (a) a message to a neighboring factor and (b) the final marginal of a variable output by BP.

Training graphical models

can be understood as training a message-passing algorithm

with a particular class of predictors defined by Equation 1,

which have as parameters the potential functions φf. Under

this general view, there is no reason to restrict attention to

only predictors of the form of Equation 1. We now have

the possibility of using different classes of predictors (e.g.,

Logistic Regression, Boosted Trees, Random Forests, etc.)

whose inductive bias may more efficiently represent inter-

actions between neighboring variables or in some cases be

more compact and faster to compute, which is important in

real-time settings.

Many other techniques for approximate inference have

been framed in message-passing form. Tree-Weighted BP

[23] and convergent variants follow a similar pattern to BP

as described above but change the specific form of messages

sent to provide stronger performance and convergence guar-

antees. These can also be interpreted as performing a se-

quence of probabilistic classifications, but using a different

form of predictors. The classical “mean-field” method (and more generally variational methods [15]) is easily framed as a simpler message-passing strategy where, instead of cavity marginals, algorithms pass around marginals or expected sufficient statistics, which is usually more efficient but obtains lower performance than cavity message passing [15]. We also consider training such a mean-field inference approach in the experiments.
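A rough sketch of this mean-field variant, where plain marginals (not cavity marginals) are passed around; the three-node ring and agreement potential below are invented for illustration:

```python
import numpy as np

# Toy mean-field-style sketch: each node repeatedly re-estimates its
# marginal from its neighbors' *current marginals* (not cavity marginals),
# on a 3-node ring with a made-up pairwise agreement potential.
K = 2
log_phi = np.log(np.array([[0.8, 0.2],    # log pairwise potential
                           [0.2, 0.8]]))
unary = np.log(np.array([[0.6, 0.4],      # log unary potentials per node
                         [0.5, 0.5],
                         [0.3, 0.7]]))
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1]}

q = np.full((3, K), 1.0 / K)              # start from uniform marginals
for _ in range(20):                        # sweep until (approximate) fixpoint
    for v in range(3):
        s = unary[v].copy()
        for u in neighbors[v]:
            s += q[u] @ log_phi            # expected log-potential under q_u
        q[v] = np.exp(s - s.max())
        q[v] /= q[v].sum()

assert np.allclose(q.sum(axis=1), 1.0)
```

Each update consumes less information than a cavity message (the sender's own influence is not removed), which matches the efficiency/accuracy trade-off the text describes.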

3. Learning Message-Passing Inference

Training the cavity marginals’ predictors in the deep in-

ference network described above remains a non-trivial task.

As we saw in Fig. 2, the sequence of predictions forms a

large network where a predictor is applied at each node in

this network, similarly to a deep neural network. In gen-

eral, minimizing the loss of the output of such a network

is difficult since it is a non-convex optimization problem,

because the outputs of previous classifications are used as

input for following classifications. However here there are

several differences that make this training an easier problem

than training general networks. First, the number of param-

eters is small as we assume that the same predictor is used

at every node in this large network. Additionally, we can

exploit local sources of information (i.e., the variables' target labels) to train the “hidden layer” nodes of this network.

Because each node corresponds to the computation of a par-

ticular variable’s marginal, we can always try to make these

marginals match the ideal output in our training data (i.e., a

marginal with probability 1 to the correct class).

Hence our general strategy for optimizing the message-

passing procedure will be to first use the local information

to train a predictor (or sequence of predictors) that predicts

the ideal variable marginals (messages) under the distribution of inputs it encounters during the inference process. We

refer to this step as local training. For cases where we train

a differentiable predictor, we can use the local training pro-

cedure to obtain a good starting point, and seek to optimize

the global non-convex objective (i.e., minimize logistic loss

only on the final marginals) using a descent procedure (i.e.,

back-propagation through this large network). We refer to

this second step as global training. The local training step

is still non-trivial as it corresponds to a non-i.i.d. super-

vised learning problem, i.e., previous classifications in the

network influence the future inputs that the predictor is evaluated on. As statistical learning approaches typically assume i.i.d.

data, it has been shown that typical supervised learning ap-

proaches have poor performance guarantees in this setting.

Fortunately, recent work [17, 18, 4] has presented iterative training algorithms that can provide good guarantees.

We leverage these techniques below and present how these

techniques can be used in our setting.

3.1. Local Training for Synchronous Message-

Passing

In synchronous message-passing, messages from nodes

to their neighbors are only sent once all messages at each

for n = 1 to N do
  Use h_{1:n-1} to perform synchronous message-passing on training graphs up to pass n.
  Get dataset D_n of inputs encountered at pass n, with the ideal marginals as target.
  Train h_n on D_n to minimize a loss (e.g., logistic).
end for
Return the sequence of predictors h_{1:N}.

Algorithm 3.1: Forward Training Algorithm for Learning

Synchronous Message-Passing.
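A minimal sketch of Algorithm 3.1 under strong simplifying assumptions: the "graph" is 5 isolated nodes, each node sees only its own previous-pass marginal in place of neighbor messages, and each h_n is a small softmax classifier fit by a few gradient steps. None of these choices come from the paper; they only make the forward-training loop concrete:

```python
import numpy as np

# Illustrative forward training: one predictor per inference pass, each
# trained on the inputs induced by the previously learned passes.
rng = np.random.default_rng(1)
K, D, N_PASSES = 2, 2, 3

def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def fit(X, Y, steps=200, lr=0.5):
    """Tiny softmax regression trained by gradient descent on logistic loss."""
    W = np.zeros((X.shape[1], K))
    for _ in range(steps):
        W -= lr * X.T @ (softmax(X @ W) - Y) / len(X)
    return W

# Toy labeled "graph": 5 nodes, features = noisy label indicators.
labels = rng.integers(0, K, size=5)
x_feats = np.eye(K)[labels] + 0.5 * rng.normal(size=(5, D))
targets = np.eye(K)[labels]                 # ideal marginals (prob. 1)

predictors = []
msgs = np.full((5, K), 1.0 / K)             # pass 1 input: uniform messages
for n in range(N_PASSES):
    # Inputs at pass n under the previously learned predictors' messages.
    X = np.hstack([x_feats, msgs])
    W = fit(X, targets)                      # train h_n on D_n
    predictors.append(W)
    msgs = softmax(X @ W)                    # run pass n to feed pass n+1

assert len(predictors) == N_PASSES
assert np.allclose(msgs.sum(axis=1), 1.0)
```

The key property the loop preserves is that each h_n is fit on exactly the input distribution it will face at test time.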

node have been computed. Our goal is to train a predictor

that performs well under the distribution of inputs (features

and messages) induced by the predictors used at previous

passes of inference. A strategy that we analyzed previously

in [17] to optimize such a network of modules is to simply

train a sequence of predictors, rather than a single predictor,

and train the predictors in sequence starting from the first

one. At iteration n, the previously learned predictors can

be used to generate inputs for training the nthpredictor in

the sequence. This guarantees that each predictor is trained

under the distribution of inputs it expects to see at test time.

In our synchronous message-passing scenario, this leads

to learning a different predictor for each inference pass. The

first predictor is trained to predict the ideal marginal at each

node given no information (uniform distribution messages)

from their neighbors. This predictor can then be used to perform a first pass of inference on all the nodes. The algorithm

then iterates until a predictor for each inference pass has

been trained. At the nthiteration, the predictor is trained to

predict the ideal node marginals at the nthinference pass,

given the neighbors’ messages obtained after applying the

previously learned n − 1 predictors on the training graphs

(scenes) for n − 1 inference passes (Algorithm 3.1).

In [17, 18], we showed that this forward training procedure guarantees that the expected sum of the loss in the sequence is bounded by Nε̄, where ε̄ is the average true loss of the learned predictors h_{1:N}. In our scenario, we are only concerned with the loss at the last inference pass. Unfortunately, naively applying this guarantee would tell us that the expected loss at the last pass is bounded by Nε̄ (e.g., in the worst case where all the loss occurs at the last pass), which would suggest that fewer inference passes is better (making N small). However, for convex loss functions, such as the logistic loss, simply averaging the output node marginals at each pass, and using those average marginals as final output, guarantees² achieving loss no worse than ε̄. Hence, using those average marginals as final output enables using an arbitrary number of passes to ensure we can effectively find the best decoding.
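Footnote 2's convexity argument is easy to verify numerically; the per-pass marginals below are arbitrary made-up values:

```python
import numpy as np

# Check: for a convex loss, the loss of the averaged marginals never
# exceeds the average of the per-pass losses (Jensen's inequality).
def logistic_loss(p, y):
    """Negative log-likelihood of the true class y."""
    return -np.log(p[y])

passes = [np.array([0.2, 0.8]),   # made-up marginals from 3 passes
          np.array([0.6, 0.4]),
          np.array([0.5, 0.5])]
y = 0                              # true label

avg_marginal = np.mean(passes, axis=0)
avg_loss = np.mean([logistic_loss(p, y) for p in passes])
assert logistic_loss(avg_marginal, y) <= avg_loss + 1e-12
```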

² If f is convex and p̄ = (1/N) Σ_{i=1}^N p_i, then f(p̄) ≤ (1/N) Σ_{i=1}^N f(p_i).


Some recent work is related to our approach. [19] demonstrates that constrained simple classification can provide good performance in NLP applications. The technique of [21] can be understood as using forward training on a synchronous message passing using only marginals, similar to mean-field inference. Similarly, from our point of view, [13] implements a ”half-pass” of hierarchical mean-field message passing by descending once down a hierarchy making contextual predictions. We demonstrate in our experiments the benefits of enabling more general (BP-style) message passing.

3.2. Local Training for Asynchronous Message-

Passing

In asynchronous message-passing, messages from nodes

to their neighbors are sent immediately. This creates a long

sequence of dependent messages that grows with the num-

ber of nodes, in addition to the number of inference passes.

Hence the previous forward training procedure is imprac-

tical in this case for large graphs, as it requires training a

large number of predictors. Fortunately, an iterative ap-

proach called Dataset Aggregation (DAgger) [18] that we

developed in prior work can train a single predictor to pro-

duce all predictions in the sequence and still guarantees

good performance on its induced distribution of inputs over

the sequence. For our asynchronous message-passing set-

ting, DAgger proceeds as follows. Initially inference is per-

formed on the training graphs by using the ideal marginals

from the training data to classify each node and generate

a first training distribution of inputs. The dataset of inputs

encountered during inference and target ideal marginals at

each node is used to learn a first predictor. Then the process

keeps iterating by using the previously learned predictor to

perform inference on the training graphs and generate a new

dataset of encountered inputs during inference, with the as-

sociated ideal marginals. This new dataset is aggregated

to the previous one and a new predictor is trained on this

aggregated dataset (i.e., containing all data collected so far

over all iterations of the algorithm). This algorithm is sum-

marized in Algorithm 3.2. [18] showed that for strongly

convex losses, such as regularized logistic loss, this algo-

rithm has the following guarantee:

Theorem 3.1. [18] There exists a predictor h_n in the sequence h_{1:N} such that E_{x∼d_{h_n}}[ℓ(x, h_n)] ≤ ε + Õ(1/N), for ε = min_{h∈H} (1/N) Σ_{i=1}^N E_{x∼d_{h_i}}[ℓ(x, h)], where d_h is the input distribution induced by predictor h. This theorem indicates that DAgger guarantees a predictor that, when used during inference, performs nearly as well as when classifying the aggregate dataset. Again, in our case we can average the predictions made at each node over the inference passes to guarantee such final predictions would have an average loss bounded by ε + Õ(1/N). To make the

Initialize D_0 ← ∅ and h_0 to return the ideal marginal on any variable v in the training graph.
for n = 1 to N do
  Use h_{n-1} to perform asynchronous message-passing inference on training graphs.
  Get dataset D'_n of inputs encountered during inference, with their ideal marginal as target.
  Aggregate datasets: D_n = D_{n-1} ∪ D'_n.
  Train h_n on D_n to minimize a loss (e.g., logistic).
end for
Return best h_n on training or validation graphs.

Algorithm 3.2: DAgger Algorithm for Learning Asynchronous Message-Passing.
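A minimal sketch of Algorithm 3.2 with the same stand-in components as before (5 toy nodes, a node's own previous marginal in place of neighbor messages, a small softmax predictor): the initial policy h_0 returns the ideal marginals, and each iteration aggregates the newly encountered inputs before retraining:

```python
import numpy as np

# Illustrative DAgger loop for asynchronous message passing: one shared
# predictor, retrained on the aggregate of all inputs it has induced.
rng = np.random.default_rng(2)
K, D, N_ITERS = 2, 2, 5

def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def fit(X, Y, steps=200, lr=0.5):
    W = np.zeros((X.shape[1], K))
    for _ in range(steps):
        W -= lr * X.T @ (softmax(X @ W) - Y) / len(X)
    return W

labels = rng.integers(0, K, size=5)
x_feats = np.eye(K)[labels] + 0.5 * rng.normal(size=(5, D))
targets = np.eye(K)[labels]              # ideal marginals

def run_inference(policy):
    """Asynchronous-style sweep: update nodes in turn with the current
    policy; return the inputs encountered along the way."""
    msgs = np.full((5, K), 1.0 / K)
    inputs = []
    for v in range(5):
        xin = np.concatenate([x_feats[v], msgs[v]])
        inputs.append(xin)
        msgs[v] = policy(v, xin)
    return np.stack(inputs)

expert = lambda v, xin: targets[v]       # h_0: returns the ideal marginal
D_agg_X = run_inference(expert)          # iteration 1 inputs
D_agg_Y = targets.copy()
W = fit(D_agg_X, D_agg_Y)
for _ in range(N_ITERS - 1):
    learned = lambda v, xin, W=W: softmax(xin[None, :] @ W)[0]
    X_new = run_inference(learned)       # inputs induced by learned predictor
    D_agg_X = np.vstack([D_agg_X, X_new])
    D_agg_Y = np.vstack([D_agg_Y, targets])
    W = fit(D_agg_X, D_agg_Y)            # retrain on the aggregate dataset

assert D_agg_X.shape[0] == 5 * N_ITERS
```

The aggregate dataset is what yields the no-regret-style guarantee of Theorem 3.1: the final predictor has been trained on samples from every input distribution it induced along the way.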

factor Õ(1/N) negligible when looking at the sum of loss over the whole graph, we can choose N to be on the order of the number of nodes in a graph. Though in practice, a much smaller number of iterations (N ∈ [10, 20]) is often sufficient to obtain good predictors under their induced distributions.

3.3. Global Training via Back-Propagation

In both synchronous and asynchronous approaches, the

local training procedures provide rigorous performance

bounds on the loss of the final predictions; however, they

do not optimize it directly. If the predictors learned are

differentiable functions, a procedure like back-propagation [12] makes it possible to identify local optima of the objective (minimizing loss of the final marginals). As this op-

timization problem is non-convex and there are potentially

many local minima, it can be crucial to initialize this de-

scent procedure with a good starting point. The forward

training and DAgger algorithms provide such an initial-

ization. In our setting, Back-Propagation effectively uses

the current predictor (or sequence of) to do inference on a

training graph (forward propagation); then errors are back-propagated through the network of classifications by rewinding the inference, successively computing derivatives of the

output error with respect to parameters and input messages.
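A toy sketch of this global refinement step: two unrolled prediction steps share one weight matrix, and we descend on the logistic loss of the final marginal only. For brevity, the gradient through the unrolled sequence is taken by finite differences rather than hand-coded back-propagation; all dimensions and values are invented:

```python
import numpy as np

# Global training sketch: unroll two inference steps sharing parameters W,
# and refine W on the loss of the *final* marginal.
rng = np.random.default_rng(3)
K, D = 2, 2
x = rng.normal(size=D)                   # local features of one node
y = 0                                    # true label
W = rng.normal(size=(D + K, K)) * 0.1    # shared predictor parameters

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def final_loss(W):
    m = np.full(K, 1.0 / K)                    # pass 1: uniform message in
    m = softmax(np.concatenate([x, m]) @ W)    # first prediction
    m = softmax(np.concatenate([x, m]) @ W)    # pass 2 reuses the same W
    return -np.log(m[y])                       # loss on the final marginal

def num_grad(f, W, eps=1e-5):
    """Finite-difference gradient standing in for back-propagation."""
    G = np.zeros_like(W)
    for i in np.ndindex(W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[i] += eps
        Wm[i] -= eps
        G[i] = (f(Wp) - f(Wm)) / (2 * eps)
    return G

before = final_loss(W)
for _ in range(50):                      # plain gradient descent on W
    W -= 0.5 * num_grad(final_loss, W)
assert final_loss(W) < before
```

In practice one would back-propagate analytically through the unrolled predictions, with the locally trained predictor providing the initialization, as described above.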

4. Experiments: Scene Labeling

To demonstrate the efficacy of our approach, we com-

pared our performance to state-of-the-art algorithms on two

labeling problems from publicly available datasets: (1) 3D

point cloud classification from a laser scanner and (2) 3D

surface layout estimation from a single image.

4.1. Datasets

3D Point Cloud Classification. We evaluate on the 3D

point cloud dataset³ used in [14]. This dataset consists of

17 full 3D laser scans (total of ∼1.6 million 3D points) of

³ http://www.cs.cmu.edu/~vmr/datasets/oakland_3d/cvpr09/
