
An Implementation of CRF (Conditional Random Fields) in PyTorch 1.0


rikeda71/TorchCRF


Torch CRF


Implementation of CRF (Conditional Random Fields) in PyTorch

Requirements

  • python3 (>=3.6)
  • PyTorch (>=1.0)

Installation

$ pip install TorchCRF

Usage

>>> import torch
>>> from TorchCRF import CRF
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> batch_size = 2
>>> sequence_size = 3
>>> num_labels = 5
>>> mask = torch.ByteTensor([[1, 1, 1], [1, 1, 0]]).to(device)  # (batch_size, sequence_size)
>>> labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).to(device)  # (batch_size, sequence_size)
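>>> hidden = torch.randn((batch_size, sequence_size, num_labels), device=device, requires_grad=True)  # (batch_size, sequence_size, num_labels)
>>> crf = CRF(num_labels).to(device)  # keep the CRF parameters on the same device as the inputs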
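The mask marks which positions of each padded sequence hold real tokens (1) and which are padding (0); here the second sequence has only two tokens. A minimal sketch, using a hypothetical helper named lengths_to_mask (not part of TorchCRF), of building such a mask from sequence lengths in plain Python:

```python
from typing import List


def lengths_to_mask(lengths: List[int], max_len: int) -> List[List[int]]:
    """Hypothetical helper: 1 for real tokens, 0 for padding positions."""
    return [[1 if t < length else 0 for t in range(max_len)] for length in lengths]


# The two sequences above have lengths 3 and 2, padded to sequence_size = 3.
print(lengths_to_mask([3, 2], 3))  # [[1, 1, 1], [1, 1, 0]]
```

In PyTorch the same mask can be computed in one vectorized step as torch.arange(max_len)[None, :] < lengths[:, None], which yields a bool tensor; call .byte() on it if a byte mask like the one above is wanted.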

Computing the log-likelihood (used in the forward pass)

>>> crf.forward(hidden, labels, mask)
tensor([-7.6204, -3.6124], device='cuda:0', grad_fn=<ThSubBackward>)
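The two values above are per-sequence log-likelihoods: the score of the gold label path minus the log partition function, i.e. the log-sum-exp of the scores of all possible label paths. The sketch below computes this for one unpadded sequence in plain Python. It assumes the standard linear-chain CRF parameterization with start, transition, and end scores; the parameter names (start, trans, end) are illustrative and not necessarily TorchCRF's attribute names, and this is not the library's own code.

```python
import math
from typing import List, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def crf_log_likelihood(
    emissions: List[List[float]],  # emissions[t][j]: score of label j at step t
    labels: List[int],             # gold label at each step
    start: List[float],            # start[j]: score of starting in label j (assumed)
    trans: List[List[float]],      # trans[i][j]: score of moving from label i to j
    end: List[float],              # end[j]: score of ending in label j (assumed)
) -> float:
    num_labels = len(start)
    # Score of the gold path: start + emissions + transitions + end.
    gold = start[labels[0]] + emissions[0][labels[0]]
    for t in range(1, len(labels)):
        gold += trans[labels[t - 1]][labels[t]] + emissions[t][labels[t]]
    gold += end[labels[-1]]

    # Forward algorithm: alpha[j] is the log-sum-exp of all prefix scores ending in j.
    alpha = [start[j] + emissions[0][j] for j in range(num_labels)]
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + trans[i][j] for i in range(num_labels)]) + emissions[t][j]
            for j in range(num_labels)
        ]
    log_z = logsumexp([alpha[j] + end[j] for j in range(num_labels)])
    return gold - log_z
```

For a padded batch, each row would first be truncated to sum(mask[b]) steps. The result is never greater than 0, and its negation (averaged over the batch) is the usual training loss.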

Decoding (predicting the label sequence of each input)

>>> crf.viterbi_decode(hidden, mask)
[[0, 2, 2], [4, 0]]
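viterbi_decode returns, for each sequence, the highest-scoring label path over its unmasked steps only, which is why the second result above has two labels. A minimal plain-Python sketch of the Viterbi algorithm for a single sequence, assuming the standard linear-chain CRF parameterization with start, transition, and end scores; the function and parameter names are illustrative and do not describe TorchCRF's internals.

```python
from typing import List


def viterbi_decode(
    emissions: List[List[float]],  # emissions[t][j]: score of label j at step t
    start: List[float],            # start[j]: score of starting in label j (assumed)
    trans: List[List[float]],      # trans[i][j]: score of moving from label i to j
    end: List[float],              # end[j]: score of ending in label j (assumed)
    length: int,                   # number of unmasked steps, i.e. sum(mask[b])
) -> List[int]:
    num_labels = len(start)
    # score[j]: best score of any path over steps 0..t that ends in label j.
    score = [start[j] + emissions[0][j] for j in range(num_labels)]
    backpointers: List[List[int]] = []
    for t in range(1, length):
        # For each current label j, remember the best previous label.
        best_prev = [
            max(range(num_labels), key=lambda i: score[i] + trans[i][j])
            for j in range(num_labels)
        ]
        score = [
            score[best_prev[j]] + trans[best_prev[j]][j] + emissions[t][j]
            for j in range(num_labels)
        ]
        backpointers.append(best_prev)
    # Pick the best final label, then follow the backpointers to recover the path.
    last = max(range(num_labels), key=lambda j: score[j] + end[j])
    path = [last]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    path.reverse()
    return path
```

Keeping only the best predecessor per label makes decoding O(length × num_labels²) instead of enumerating all num_labels^length paths.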

License

MIT

