lda.py
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
This example implements amortized Latent Dirichlet Allocation [1],
demonstrating how to marginalize out discrete assignment variables in a Pyro
model. This model and inference algorithm treat documents as vectors of
categorical variables (vectors of word ids), and collapses word-topic
assignments using Pyro's enumeration. We use PyTorch's reparametrized Gamma and
Dirichlet distributions [2], avoiding the need for Laplace approximations as in
[1]. Following [1] we use the Adam optimizer and clip gradients.
**References:**
[1] Akash Srivastava, Charles Sutton. ICLR 2017.
"Autoencoding Variational Inference for Topic Models"
https://arxiv.org/pdf/1703.01488.pdf
[2] Martin Jankowiak, Fritz Obermeyer. ICML 2018.
"Pathwise gradients beyond the reparametrization trick"
https://arxiv.org/pdf/1806.01851.pdf
"""
import argparse
import functools
import logging

import torch
from torch import nn
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.optim import ClippedAdam

logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)


# This is a fully generative model of a batch of documents.
# data is a [num_words_per_doc, num_documents] shaped array of word ids
# (specifically it is not a histogram). We assume in this simple example
# that all documents have the same number of words.
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample(
            "topic_weights", dist.Gamma(1.0 / args.num_topics, 1.0)
        )
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words)
        )

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide. (A standalone sketch of this pattern follows the model
            # below.)
            word_topics = pyro.sample(
                "word_topics",
                dist.Categorical(doc_topics),
                infer={"enumerate": "parallel"},
            )
            data = pyro.sample(
                "doc_words", dist.Categorical(topic_words[word_topics]), obs=data
            )

    return topic_weights, topic_words, data
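

# A minimal standalone sketch (not used by this example) of the enumeration
# mechanism referenced above: any model-side discrete site marked with
# infer={"enumerate": "parallel"} is summed out exactly by TraceEnum_ELBO, so
# the guide never has to sample it. The toy two-component mixture below is
# hypothetical and only illustrates the pattern:
#
#     def toy_mixture(data):
#         probs = pyro.param("probs", torch.tensor([0.5, 0.5]),
#                            constraint=constraints.simplex)
#         locs = pyro.param("locs", torch.tensor([-1.0, 1.0]))
#         with pyro.plate("data", len(data)):
#             z = pyro.sample("z", dist.Categorical(probs),
#                             infer={"enumerate": "parallel"})
#             pyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)
#
#     # With TraceEnum_ELBO(max_plate_nesting=1), "z" is enumerated in parallel
#     # over its two values and marginalized out of the ELBO, exactly as
#     # "word_topics" is handled in model() above.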


# We will use amortized inference of the local topic variables, achieved by a
# multi-layer perceptron. We'll wrap the guide in an nn.Module.
def make_predictor(args):
    layer_sizes = (
        [args.num_words]
        + [int(s) for s in args.layer_sizes.split("-")]
        + [args.num_topics]
    )
    logging.info("Creating MLP with sizes {}".format(layer_sizes))
    layers = []
    for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        layer = nn.Linear(in_size, out_size)
        layer.weight.data.normal_(0, 0.001)
        layer.bias.data.normal_(0, 0.001)
        layers.append(layer)
        layers.append(nn.Sigmoid())
    layers.append(nn.Softmax(dim=-1))
    return nn.Sequential(*layers)
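

# Worked example of the shapes involved, using the default command-line flags
# defined at the bottom of this file (num_words=1024, layer_sizes="100-100",
# num_topics=8): layer_sizes works out to [1024, 100, 100, 8], so the network
# maps a [batch_size, 1024] histogram of word counts to a [batch_size, 8]
# vector on the probability simplex (thanks to the final Softmax), which the
# guide below uses as the Delta location for doc_topics. For instance:
#
#     predictor(torch.rand(32, 1024)).shape  # -> torch.Size([32, 8])
#
# Other flag values change these shapes accordingly.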


def parametrized_guide(predictor, data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive,
    )
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5),
    )
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.0))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(args.num_words, ind.size(0)).scatter_add(
            0, data, torch.ones(data.shape)
        )
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))


def main(args):
    logging.info("Generating data")
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # We can generate synthetic data directly by calling the model.
    true_topic_weights, true_topic_words, data = model(args=args)

    # We'll train using SVI.
    logging.info("-" * 40)
    logging.info("Training on {} documents".format(args.num_docs))
    predictor = make_predictor(args)
    guide = functools.partial(parametrized_guide, predictor)
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=2)
    optim = ClippedAdam({"lr": args.learning_rate})
    svi = SVI(model, guide, optim, elbo)
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(data, args=args, batch_size=args.batch_size)
        if step % 10 == 0:
            logging.info("{: >5d}\t{}".format(step, loss))
    loss = elbo.loss(model, guide, data, args=args)
    logging.info("final loss = {}".format(loss))


if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(
        description="Amortized Latent Dirichlet Allocation"
    )
    parser.add_argument("-t", "--num-topics", default=8, type=int)
    parser.add_argument("-w", "--num-words", default=1024, type=int)
    parser.add_argument("-d", "--num-docs", default=1000, type=int)
    parser.add_argument("-wd", "--num-words-per-doc", default=64, type=int)
    parser.add_argument("-n", "--num-steps", default=1000, type=int)
    parser.add_argument("-l", "--layer-sizes", default="100-100")
    parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("--jit", action="store_true")
    args = parser.parse_args()
    main(args)
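

# Example invocations, using the flags defined above (any can be overridden):
#
#     python lda.py                          # defaults: 8 topics, 1000 SVI steps
#     python lda.py -t 16 -w 2048 -n 2000    # more topics, larger vocabulary
#     python lda.py --jit                    # use the jitted JitTraceEnum_ELBO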