Skip to content

Commit

Permalink
Specify torch's types explicitly bc they are mocked
Browse files Browse the repository at this point in the history
- Close kmkurn#14
  • Loading branch information
kmkurn committed Feb 4, 2019
1 parent 05bfb49 commit 805c9a8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
1 change: 0 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ API documentation

.. autoclass:: torchcrf.CRF
:members:
:show-inheritance:

Indices and tables
==================
Expand Down
29 changes: 15 additions & 14 deletions torchcrf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,22 @@ def forward(
"""Compute the conditional log likelihood of a sequence of tags given emission scores.
Args:
emissions: Emission score tensor of size ``(seq_length, batch_size, num_tags)``
if ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)``
otherwise.
tags: Sequence of tags tensor of size ``(seq_length, batch_size)`` if
``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
mask: Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is
``False``, ``(batch_size, seq_length)`` otherwise.
emissions (`~torch.Tensor`): Emission score tensor of size
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
``(batch_size, seq_length, num_tags)`` otherwise.
tags (`~torch.LongTensor`): Sequence of tags tensor of size
``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
``(batch_size, seq_length)`` otherwise.
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
reduction: Specifies the reduction to apply to the output:
``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
``sum``: the output will be summed over batches. ``mean``: the output will be
averaged over batches. ``token_mean``: the output will be averaged over tokens.
Returns:
The log likelihood. This will have size ``(batch_size,)`` if reduction is ``none``,
``()`` otherwise.
`~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
reduction is ``none``, ``()`` otherwise.
"""
self._validate(emissions, tags=tags, mask=mask)
if reduction not in ('none', 'sum', 'mean', 'token_mean'):
Expand Down Expand Up @@ -116,11 +117,11 @@ def decode(self, emissions: torch.Tensor,
"""Find the most likely tag sequence using Viterbi algorithm.
Args:
emissions: Emission score tensor of size ``(seq_length, batch_size, num_tags)``
if ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)``
otherwise.
mask: Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is
``False``, ``(batch_size, seq_length)`` otherwise.
emissions (`~torch.Tensor`): Emission score tensor of size
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
``(batch_size, seq_length, num_tags)`` otherwise.
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
Returns:
List of list containing the best tag sequence for each batch.
Expand Down

0 comments on commit 805c9a8

Please sign in to comment.