pytorch-crf

Conditional random field in PyTorch.

Description

This package provides an implementation of a conditional random field (CRF) in PyTorch. The implementation borrows mostly from the AllenNLP CRF module, with some modifications.

Requirements

  • Python 3.6
  • PyTorch 1.0.0

Installation

You can install the package with pip:

pip install pytorch-crf

Or, you can install directly from GitHub:

pip install git+https://github.com/kmkurn/pytorch-crf#egg=pytorch_crf

Examples

The examples below assume that the following lines have been executed:

>>> import torch
>>> from torchcrf import CRF
>>> seq_length, batch_size, num_tags = 3, 2, 5
>>> emissions = torch.randn(seq_length, batch_size, num_tags)
>>> tags = torch.tensor([
...   [0, 1], [2, 4], [3, 1]
... ], dtype=torch.long)  # (seq_length, batch_size)
>>> model = CRF(num_tags)

Computing log likelihood

>>> model(emissions, tags)
tensor(-12.7431, grad_fn=<SumBackward0>)
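
For intuition about what this number means: in a linear-chain CRF, the log likelihood of a tag sequence is its score minus the log of the sum of exponentiated scores over all possible tag sequences. The following is a minimal pure-Python sketch for a single sequence, assuming the score is a sum of start, emission, transition, and end terms. The names `sequence_score` and `log_likelihood` are hypothetical and this is not the package's code; it enumerates every sequence by brute force, which is only feasible for tiny inputs.

```python
import itertools
import math


def sequence_score(emissions, transitions, start, end, tags):
    """Score of one tag sequence (assumed decomposition, for illustration)."""
    score = start[tags[0]] + emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score + end[tags[-1]]


def log_likelihood(emissions, transitions, start, end, tags):
    """Score of `tags` minus the log partition function (brute force)."""
    num_tags = len(start)
    all_scores = [
        sequence_score(emissions, transitions, start, end, candidate)
        for candidate in itertools.product(range(num_tags), repeat=len(emissions))
    ]
    # Log-sum-exp with the maximum subtracted for numerical stability.
    max_score = max(all_scores)
    log_partition = max_score + math.log(
        sum(math.exp(s - max_score) for s in all_scores)
    )
    return sequence_score(emissions, transitions, start, end, tags) - log_partition


# Toy parameters: 3 time steps, 2 tags (values chosen only for illustration).
toy_emissions = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
toy_transitions = [[0.5, -0.5], [-0.5, 0.5]]
toy_start, toy_end = [0.0, 0.0], [0.0, 0.0]

# The probabilities of all 2**3 tag sequences should sum to 1.
total_probability = sum(
    math.exp(log_likelihood(toy_emissions, toy_transitions, toy_start, toy_end, c))
    for c in itertools.product(range(2), repeat=3)
)
```

Because the log likelihood is at most zero and grows as the model fits the tags better, its negation is the natural choice of loss when training.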

Computing log likelihood with mask

>>> mask = torch.tensor([
...   [1, 1], [1, 1], [1, 0]
... ], dtype=torch.uint8)  # (seq_length, batch_size)
>>> model(emissions, tags, mask=mask)
tensor(-10.8390, grad_fn=<SumBackward0>)
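
The mask above treats the first sequence as having length 3 and the second as having length 2. When only the per-sequence lengths are known, a mask of this shape can be built with a small helper. The sketch below is pure Python; `lengths_to_mask` is a hypothetical name, not part of this package, and the requirement that every sequence has at least one unmasked step is an assumption of the sketch.

```python
def lengths_to_mask(lengths, seq_length):
    """Build a (seq_length, batch_size) nested list of 0/1 mask values."""
    # Assumption of this sketch: every sequence keeps at least its first step.
    if any(n < 1 or n > seq_length for n in lengths):
        raise ValueError("each length must be between 1 and seq_length")
    return [[1 if t < n else 0 for n in lengths] for t in range(seq_length)]


mask_rows = lengths_to_mask([3, 2], seq_length=3)
# mask_rows == [[1, 1], [1, 1], [1, 0]]
```

The nested list can then be wrapped as in the example above, for instance with `torch.tensor(mask_rows, dtype=torch.uint8)`.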

Decoding

>>> model.decode(emissions)
[[3, 1, 3], [0, 1, 0]]

Decoding with mask

>>> model.decode(emissions, mask=mask)
[[3, 1, 3], [0, 1]]
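
Each inner list is the highest-scoring tag sequence for one batch item; with the mask, the second list contains only two tags. A standard way to find such a sequence in a linear-chain CRF is the Viterbi algorithm. The following is a minimal pure-Python sketch for a single unmasked sequence, assuming a score made of start, emission, transition, and end terms. The name `viterbi` is hypothetical and this is not the package's code.

```python
def viterbi(emissions, transitions, start, end):
    """Return the highest-scoring tag sequence for one sequence."""
    num_tags = len(start)
    # best[j]: score of the best path so far that ends in tag j.
    best = [start[j] + emissions[0][j] for j in range(num_tags)]
    backpointers = []
    for t in range(1, len(emissions)):
        new_best, pointers = [], []
        for j in range(num_tags):
            # Best previous tag to transition from when landing on tag j.
            prev = max(range(num_tags), key=lambda i: best[i] + transitions[i][j])
            pointers.append(prev)
            new_best.append(best[prev] + transitions[prev][j] + emissions[t][j])
        best = new_best
        backpointers.append(pointers)
    # Pick the best final tag, then follow the backpointers to the start.
    path = [max(range(num_tags), key=lambda j: best[j] + end[j])]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path


# Toy inputs: 3 time steps, 2 tags (values chosen only for illustration).
toy_emissions = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
no_preference = [[0.0, 0.0], [0.0, 0.0]]
sticky = [[0.0, -5.0], [-5.0, 0.0]]  # penalizes switching tags
zeros = [0.0, 0.0]

free_path = viterbi(toy_emissions, no_preference, zeros, zeros)  # [0, 1, 0]
sticky_path = viterbi(toy_emissions, sticky, zeros, zeros)       # [0, 0, 0]
```

The second call shows why decoding a CRF differs from taking the best tag at each step independently: a strong transition penalty can override the emission scores.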

See tests/test_crf.py for more examples.

License

MIT. See LICENSE for details.

Contributing

Contributions are welcome! Please follow the instructions below to install the dependencies and run the tests and linter. Make a pull request to the develop branch once your contribution is ready.

Installing dependencies

Make sure you have set up a virtual environment with Python and PyTorch installed. Then install all the dependencies listed in the requirements.txt file and install this package in development mode.

pip install -r requirements.txt
pip install -e .

Setting up the pre-commit hook

Run the following in the project root directory:

ln -s ../../pre-commit.sh .git/hooks/pre-commit

Running tests

Run pytest in the project root directory.

Running linter

Run flake8 in the project root directory. This also runs mypy, thanks to the flake8-mypy package.
