Conditional random field in PyTorch.
This package provides an implementation of a conditional random field (CRF) in PyTorch. The implementation borrows mostly from the AllenNLP CRF module, with some modifications.
Requirements:
- Python 3.6
- PyTorch 1.0.0
You can install with pip:
pip install pytorch-crf
Or, you can install from GitHub directly:
pip install git+https://github.com/kmkurn/pytorch-crf#egg=pytorch_crf
In the examples below, we assume that these lines have been executed:
>>> import torch
>>> from torchcrf import CRF
>>> seq_length, batch_size, num_tags = 3, 2, 5
>>> emissions = torch.randn(seq_length, batch_size, num_tags)
>>> tags = torch.tensor([
... [0, 1], [2, 4], [3, 1]
... ], dtype=torch.long) # (seq_length, batch_size)
>>> model = CRF(num_tags)
>>> model(emissions, tags)
tensor(-12.7431, grad_fn=<SumBackward0>)
>>> mask = torch.tensor([
... [1, 1], [1, 1], [1, 0]
... ], dtype=torch.uint8) # (seq_length, batch_size)
>>> model(emissions, tags, mask=mask)
tensor(-10.8390, grad_fn=<SumBackward0>)
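The two calls above return a log likelihood summed over the batch (note the SumBackward0 in the output). To make that quantity concrete, here is a minimal pure-Python sketch for a single sequence, assuming the standard linear-chain CRF formulation with start, transition, and end scores. It is not the library's code, which is batched and runs on tensors; the function names and the plain-list parameters (transitions, start, end) are hypothetical and chosen only for illustration.

```python
import math
from itertools import product


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def sequence_score(emissions, tags, transitions, start, end):
    """Unnormalized score of one tag sequence."""
    score = start[tags[0]] + emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score + end[tags[-1]]


def log_partition(emissions, transitions, start, end):
    """Log of the summed exp(score) over all tag sequences (forward algorithm)."""
    num_tags = len(start)
    alpha = [start[j] + emissions[0][j] for j in range(num_tags)]
    for t in range(1, len(emissions)):
        alpha = [
            log_sum_exp([alpha[i] + transitions[i][j] for i in range(num_tags)])
            + emissions[t][j]
            for j in range(num_tags)
        ]
    return log_sum_exp([alpha[j] + end[j] for j in range(num_tags)])


def log_likelihood(emissions, tags, transitions, start, end):
    """Log probability of `tags` given `emissions` for one sequence."""
    score = sequence_score(emissions, tags, transitions, start, end)
    return score - log_partition(emissions, transitions, start, end)


# Hypothetical toy scores: 3 time steps, 2 tags.
emissions = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8]]  # (seq_length, num_tags)
transitions = [[0.2, -0.1], [0.0, 0.4]]  # transitions[i][j]: score of i -> j
start = [0.1, -0.3]
end = [0.0, 0.2]

ll = log_likelihood(emissions, [0, 1, 1], transitions, start, end)

# Sanity check: the probabilities of all 2**3 tag sequences sum to 1
# (up to floating-point error).
total = sum(
    math.exp(log_likelihood(emissions, list(seq), transitions, start, end))
    for seq in product(range(2), repeat=3)
)
print(ll, total)
```

Because the value is a log probability, it is never positive, which is why it is typically negated when used as a training loss.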
>>> model.decode(emissions)
[[3, 1, 3], [0, 1, 0]]
>>> model.decode(emissions, mask=mask)
[[3, 1, 3], [0, 1]]
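The decode calls above return the highest-scoring tag sequence for each item in the batch, which is the job of the Viterbi algorithm. The following is a minimal pure-Python sketch of Viterbi decoding for a single sequence, under the same assumed linear-chain scoring (start, transition, and end scores). It is an illustration only, not the library's implementation, and the function and parameter names are hypothetical.

```python
def viterbi_decode(emissions, transitions, start, end):
    """Return the highest-scoring tag sequence for one sequence of emissions."""
    num_tags = len(start)
    # score[j]: best score of any partial path ending in tag j at the current step
    score = [start[j] + emissions[0][j] for j in range(num_tags)]
    backpointers = []
    for t in range(1, len(emissions)):
        next_score, pointers = [], []
        for j in range(num_tags):
            # Best previous tag for landing on tag j at step t
            best_prev = max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            pointers.append(best_prev)
            next_score.append(score[best_prev] + transitions[best_prev][j] + emissions[t][j])
        score = next_score
        backpointers.append(pointers)
    # Pick the best final tag, then follow the backpointers to recover the path
    best_last = max(range(num_tags), key=lambda j: score[j] + end[j])
    path = [best_last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path


# Hypothetical toy scores: 3 time steps, 2 tags.
emissions = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8]]  # (seq_length, num_tags)
transitions = [[0.2, -0.1], [0.0, 0.4]]  # transitions[i][j]: score of i -> j
start = [0.1, -0.3]
end = [0.0, 0.2]

best_path = viterbi_decode(emissions, transitions, start, end)
print(best_path)  # [0, 1, 1]
```

For a sequence with right-side padding, this sketch would be called on the unpadded prefix only, which mirrors how the masked decode call above returns a shorter list for the second sequence.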
See tests/test_crf.py for more examples.
MIT. See LICENSE for details.
Contributions are welcome! Please follow these instructions to install the dependencies and run the tests and linter. Make a pull request to the develop branch once your contribution is ready.
Make sure you set up a virtual environment with Python and PyTorch installed. Then, install all the dependencies in the requirements.txt file and install this package in development mode:
pip install -r requirements.txt
pip install -e .
To set up the pre-commit hook, simply run:
ln -s ../../pre-commit.sh .git/hooks/pre-commit
Run pytest in the project root directory.
Run flake8 in the project root directory. This will also run mypy, thanks to the flake8-mypy package.