Skip to content

Commit

Permalink
Sampling functions for the MuE/missing data discrete HMM. (#2898)
Browse files Browse the repository at this point in the history
  • Loading branch information
EWeinstein authored Jul 22, 2021
1 parent d9c45e7 commit b704945
Show file tree
Hide file tree
Showing 8 changed files with 773 additions and 45 deletions.
24 changes: 17 additions & 7 deletions examples/contrib/mue/FactorMuE.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
Reference:
[1] E. N. Weinstein, D. S. Marks (2021)
"Generative probabilistic biological sequence models that account for
mutational variability"
"A structured observation distribution for generative biological sequence
prediction and forecasting"
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
"""

Expand Down Expand Up @@ -62,10 +62,10 @@ def generate_data(small_test, include_stop, device):
def main(args):

# Load dataset.
if args.cpu_data and args.cuda:
if args.cpu_data or not args.cuda:
device = torch.device("cpu")
else:
device = None
device = torch.device("cuda")
if args.test:
dataset = generate_data(args.small, args.include_stop, device)
else:
Expand All @@ -84,7 +84,7 @@ def main(args):
# Specific data split seed, for comparability across models and
# parameter initializations.
pyro.set_rng_seed(args.rng_data_seed)
indices = torch.randperm(sum(data_lengths)).tolist()
indices = torch.randperm(sum(data_lengths), device=device).tolist()
dataset_train, dataset_test = [
torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
for offset, length in zip(
Expand Down Expand Up @@ -131,7 +131,12 @@ def main(args):
)
n_epochs = args.n_epochs
losses = model.fit_svi(
dataset_train, n_epochs, args.anneal, args.batch_size, scheduler, args.jit
dataset_train,
n_epochs,
args.anneal,
args.batch_size,
scheduler,
args.jit,
)

# Evaluate.
Expand Down Expand Up @@ -233,13 +238,18 @@ def main(args):
)
with open(
os.path.join(
args.out_folder, "FactorMuE_results.input_{}.txt".format(time_stamp)
args.out_folder,
"FactorMuE_results.input_{}.txt".format(time_stamp),
),
"w",
) as ow:
ow.write("[args]\n")
args.latent_seq_length = model.latent_seq_length
args.latent_alphabet = model.latent_alphabet_length
for elem in list(args.__dict__.keys()):
ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
ow.write("max_length = {}\n".format(dataset.max_length))


if __name__ == "__main__":
Expand Down
13 changes: 8 additions & 5 deletions examples/contrib/mue/ProfileHMM.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
Cambridge university press
[2] E. N. Weinstein, D. S. Marks (2021)
"Generative probabilistic biological sequence models that account for
mutational variability"
"A structured observation distribution for generative biological sequence
prediction and forecasting"
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
"""

Expand Down Expand Up @@ -68,10 +68,10 @@ def main(args):
pyro.set_rng_seed(args.rng_seed)

# Load dataset.
if args.cpu_data and args.cuda:
if args.cpu_data or not args.cuda:
device = torch.device("cpu")
else:
device = None
device = torch.device("cuda")
if args.test:
dataset = generate_data(args.small, args.include_stop, device)
else:
Expand All @@ -90,7 +90,7 @@ def main(args):
# Specific data split seed, for comparability across models and
# parameter initializations.
pyro.set_rng_seed(args.rng_data_seed)
indices = torch.randperm(sum(data_lengths)).tolist()
indices = torch.randperm(sum(data_lengths), device=device).tolist()
dataset_train, dataset_test = [
torch.utils.data.Subset(dataset, indices[(offset - length) : offset])
for offset, length in zip(
Expand Down Expand Up @@ -200,8 +200,11 @@ def main(args):
"w",
) as ow:
ow.write("[args]\n")
args.latent_seq_length = model.latent_seq_length
for elem in list(args.__dict__.keys()):
ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
ow.write("max_length = {}\n".format(dataset.max_length))


if __name__ == "__main__":
Expand Down
57 changes: 57 additions & 0 deletions pyro/contrib/mue/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,60 @@ def __len__(self):
def __getitem__(self, ind):

return (self.seq_data[ind], self.L_data[ind])


def write(x, alphabet, file, truncate_stop=False, append=False, scores=None):
"""
Write sequence samples to file.
:param ~torch.Tensor x: One-hot encoded sequences, with size
``(data_size, seq_length, alphabet_length)``. May be padded with
zeros for variable length sequences.
:param ~np.array alphabet: Alphabet.
:param str file: Output file, where sequences will be written
in fasta format.
:param bool truncate_stop: If True, sequences will be truncated at the
first stop symbol (i.e. the stop symbol and everything after will not
be written). If False, the whole sequence will be written, including
any internal stop symbols.
:param bool append: If True, sequences are appended to the end of the
output file. If False, the file is first erased.
"""
print_alphabet = np.array(list(alphabet) + [""])
x = torch.cat([x, torch.zeros(list(x.shape[:2]) + [1])], -1)
if truncate_stop:
mask = (
torch.cumsum(
torch.matmul(
x, torch.tensor(print_alphabet == "*", dtype=torch.double)
),
-1,
)
> 0
).to(torch.double)
x = x * (1 - mask).unsqueeze(-1)
x[:, :, -1] = mask
else:
x[:, :, -1] = (torch.sum(x, -1) < 0.5).to(torch.double)
index = (
torch.matmul(x, torch.arange(x.shape[-1], dtype=torch.double))
.to(torch.long)
.cpu()
.numpy()
)
if scores is None:
seqs = [
">{}\n".format(j) + "".join(elem) + "\n"
for j, elem in enumerate(print_alphabet[index])
]
else:
seqs = [
">{}\n".format(j) + "".join(elem) + "\n"
for j, elem in zip(scores, print_alphabet[index])
]
if append:
open_flag = "a"
else:
open_flag = "w"
with open(file, open_flag) as fw:
fw.write("".join(seqs))
205 changes: 205 additions & 0 deletions pyro/contrib/mue/missingdatahmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import Categorical, OneHotCategorical

from pyro.distributions import constraints
from pyro.distributions.hmm import _sequential_logmatmulexp
Expand Down Expand Up @@ -110,3 +111,207 @@ def log_prob(self, value):
# Marginalize out final state.
result = result.logsumexp(-1)
return result

def sample(self, sample_shape=torch.Size([])):
"""
:param ~torch.Size sample_shape: Sample shape, last dimension must be
``num_steps`` and must be broadcastable to
``(batch_size, num_steps)``. batch_size must be int not tuple.
"""
# shape: batch_size x num_steps x categorical_size
shape = broadcast_shape(
torch.Size(list(self.batch_shape) + [1, 1]),
torch.Size(list(sample_shape) + [1]),
torch.Size((1, 1, self.event_shape[-1])),
)
# state: batch_size x state_dim
state = OneHotCategorical(logits=self.initial_logits).sample()
# sample: batch_size x num_steps x categorical_size
sample = torch.zeros(shape)
for i in range(shape[-2]):
# batch_size x 1 x state_dim @
# batch_size x state_dim x categorical_size
obs_logits = torch.matmul(
state.unsqueeze(-2), self.observation_logits
).squeeze(-2)
sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
# batch_size x 1 x state_dim @
# batch_size x state_dim x state_dim
trans_logits = torch.matmul(
state.unsqueeze(-2), self.transition_logits
).squeeze(-2)
state = OneHotCategorical(logits=trans_logits).sample()

return sample

def filter(self, value):
"""
Compute the marginal probability of the state variable at each
step conditional on the previous observations.
:param ~torch.Tensor value: One-hot encoded observation.
Must be real-valued (float) and broadcastable to
``(batch_size, num_steps, categorical_size)`` where
``categorical_size`` is the dimension of the categorical output.
"""
# batch_size x num_steps x state_dim
shape = broadcast_shape(
torch.Size(list(self.batch_shape) + [1, 1]),
torch.Size(list(value.shape[:-1]) + [1]),
torch.Size((1, 1, self.initial_logits.shape[-1])),
)
filter = torch.zeros(shape)

# Combine observation and transition factors.
# batch_size x num_steps x state_dim
value_logits = torch.matmul(
value, torch.transpose(self.observation_logits, -2, -1)
)
# batch_size x num_steps-1 x state_dim x state_dim
result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]

# Forward pass. (This could be parallelized using the
# Sarkka & Garcia-Fernandez method.)
filter[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
filter[..., 0, :] = filter[..., 0, :] - torch.logsumexp(
filter[..., 0, :], -1, True
)
for i in range(1, shape[-2]):
filter[..., i, :] = torch.logsumexp(
filter[..., i - 1, :, None] + result[..., i - 1, :, :], -2
)
filter[..., i, :] = filter[..., i, :] - torch.logsumexp(
filter[..., i, :], -1, True
)
return filter

def smooth(self, value):
"""
Compute posterior expected value of state at each position (smoothing).
:param ~torch.Tensor value: One-hot encoded observation.
Must be real-valued (float) and broadcastable to
``(batch_size, num_steps, categorical_size)`` where
``categorical_size`` is the dimension of the categorical output.
"""
# Compute filter and initialize.
filter = self.filter(value)
shape = filter.shape
backfilter = torch.zeros(shape)

# Combine observation and transition factors.
# batch_size x num_steps x state_dim
value_logits = torch.matmul(
value, torch.transpose(self.observation_logits, -2, -1)
)
# batch_size x num_steps-1 x state_dim x state_dim
result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
# Construct backwards filter.
for i in range(shape[-2] - 1, 0, -1):
backfilter[..., i - 1, :] = torch.logsumexp(
backfilter[..., i, None, :] + result[..., i - 1, :, :], -1
)

# Compute smoothed version.
smooth = filter + backfilter
smooth = smooth - torch.logsumexp(smooth, -1, True)
return smooth

def sample_states(self, value):
"""
Sample states with forward filtering-backward sampling algorithm.
:param ~torch.Tensor value: One-hot encoded observation.
Must be real-valued (float) and broadcastable to
``(batch_size, num_steps, categorical_size)`` where
``categorical_size`` is the dimension of the categorical output.
"""
filter = self.filter(value)
shape = filter.shape
joint = filter.unsqueeze(-1) + self.transition_logits.unsqueeze(-3)
states = torch.zeros(shape[:-1], dtype=torch.long)
states[..., -1] = Categorical(logits=filter[..., -1, :]).sample()
for i in range(shape[-2] - 1, 0, -1):
logits = torch.gather(
joint[..., i - 1, :, :],
-1,
states[..., i, None, None]
* torch.ones([shape[-1], 1], dtype=torch.long),
).squeeze(-1)
states[..., i - 1] = Categorical(logits=logits).sample()
return states

def map_states(self, value):
"""
Compute maximum a posteriori (MAP) estimate of state variable with
Viterbi algorithm.
:param ~torch.Tensor value: One-hot encoded observation.
Must be real-valued (float) and broadcastable to
``(batch_size, num_steps, categorical_size)`` where
``categorical_size`` is the dimension of the categorical output.
"""
# Setup for Viterbi.
# batch_size x num_steps x state_dim
shape = broadcast_shape(
torch.Size(list(self.batch_shape) + [1, 1]),
torch.Size(list(value.shape[:-1]) + [1]),
torch.Size((1, 1, self.initial_logits.shape[-1])),
)
state_logits = torch.zeros(shape)
state_traceback = torch.zeros(shape, dtype=torch.long)

# Combine observation and transition factors.
# batch_size x num_steps x state_dim
value_logits = torch.matmul(
value, torch.transpose(self.observation_logits, -2, -1)
)
# batch_size x num_steps-1 x state_dim x state_dim
result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]

# Forward pass.
state_logits[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
for i in range(1, shape[-2]):
transit_weights = (
state_logits[..., i - 1, :, None] + result[..., i - 1, :, :]
)
state_logits[..., i, :], state_traceback[..., i, :] = torch.max(
transit_weights, -2
)
# Traceback.
map_states = torch.zeros(shape[:-1], dtype=torch.long)
map_states[..., -1] = torch.argmax(state_logits[..., -1, :], -1)
for i in range(shape[-2] - 1, 0, -1):
map_states[..., i - 1] = torch.gather(
state_traceback[..., i, :], -1, map_states[..., i].unsqueeze(-1)
).squeeze(-1)
return map_states

def given_states(self, states):
"""
Distribution conditional on the state variable.
:param ~torch.Tensor map_states: State trajectory. Must be
integer-valued (long) and broadcastable to
``(batch_size, num_steps)``.
"""
shape = broadcast_shape(
list(self.batch_shape) + [1, 1],
list(states.shape[:-1]) + [1, 1],
[1, 1, self.observation_logits.shape[-1]],
)
states_index = states.unsqueeze(-1) * torch.ones(shape, dtype=torch.long)
obs_logits = self.observation_logits * torch.ones(shape)
logits = torch.gather(obs_logits, -2, states_index)
return OneHotCategorical(logits=logits)

def sample_given_states(self, states):
"""
Sample an observation conditional on the state variable.
:param ~torch.Tensor map_states: State trajectory. Must be
integer-valued (long) and broadcastable to
``(batch_size, num_steps)``.
"""
conditional = self.given_states(states)
return conditional.sample()
Loading

0 comments on commit b704945

Please sign in to comment.