How to use CRF in TensorFlow Keras?
My code looks like this:
import tensorflow as tf
from keras_contrib.layers import CRF
from tensorflow import keras
def create_model(max_seq_len, adapter_size=64):
    """Creates a classification model."""
    # adapter_size = 64  # see - arXiv:1902.00751
    # create the bert layer
    with tf.io.gfile.GFile(bert_config_file, "r") as reader:
        bc = StockBertConfig.from_json_string(reader.read())
        bert_params = map_stock_config_to_params(bc)
        bert_params.adapter_size = adapter_size
        bert = BertModelLayer.from_params(bert_params, name="bert")
    input_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="input_ids")
    # token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="token_type_ids")
    # output         = bert([input_ids, token_type_ids])
    bert_output = bert(input_ids)
    print("bert_output.shape: {}".format(bert_output.shape))  # (?, 100, 768)
    crf = CRF(len(tag2idx))
    logits = crf(bert_output)
    model = keras.Model(inputs=input_ids, outputs=logits)
    model.build(input_shape=(None, max_seq_len))
    # load the pre-trained model weights
    load_stock_weights(bert, bert_ckpt_file)
    # freeze weights if adapter-BERT is used
    if adapter_size is not None:
        freeze_bert_layers(bert)
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    model.summary()
    return model
I am using TensorFlow's Keras (tf.keras) together with the keras_contrib package to do NER. It seems that tf.keras does not work well with keras_contrib.
Here is the traceback:
Traceback (most recent call last):
  File "F:/_gitclone3/bert_examples/bert_ner_example_eval.py", line 120, in <module>
    model = create_model(max_seq_len, adapter_size=adapter_size)
  File "F:/_gitclone3/bert_examples/bert_ner_example_eval.py", line 101, in create_model
    logits = crf(bert_output)
  File "C:\Users\yuexiang\Anaconda3\lib\site-packages\keras\engine\base_layer.py", line 443, in __call__
    previous_mask = _collect_previous_mask(inputs)
  File "C:\Users\yuexiang\Anaconda3\lib\site-packages\keras\engine\base_layer.py", line 1311, in _collect_previous_mask
    mask = node.output_masks[tensor_index]
AttributeError: 'Node' object has no attribute 'output_masks'
How do I use CRF with TensorFlow Keras?
Solution 1:[1]
I ran into a similar problem and spent a lot of time trying to get things to work. Here's what worked for me using Python 3.6.5:
Seqeval:
pip install seqeval==0.0.5
Keras:
pip install keras==2.2.4
Keras-contrib (2.0.8):
git clone https://www.github.com/keras-team/keras-contrib.git
cd keras-contrib
python setup.py install
TensorFlow:
pip install tensorflow==1.14.0
Run pip list to make sure you have actually installed those versions (e.g. installing seqeval with pip may automatically upgrade your Keras).
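If you would rather check the pins from a script than read through pip list, a minimal sketch along these lines compares the installed versions with the ones above. It assumes Python 3.8+ for importlib.metadata (on Python 3.6 you would need the importlib_metadata backport instead). EXPECTED and check_versions are hypothetical names, and keras-contrib is left out because it is installed from git.

```python
# Minimal sketch: compare installed package versions against the pins above.
# Assumption: Python 3.8+ (importlib.metadata); on 3.6 use the importlib_metadata backport.
from importlib.metadata import PackageNotFoundError, version

# Pins from the answer; keras-contrib is installed from git, so it is not checked here.
EXPECTED = {
    "seqeval": "0.0.5",
    "keras": "2.2.4",
    "tensorflow": "1.14.0",
}


def check_versions(expected, get_version=version):
    """Return {package: (expected, installed)} for every package that does not match.

    installed is None when the package is not installed at all.
    """
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_versions(EXPECTED)
    for package, (wanted, installed) in problems.items():
        print(f"{package}: expected {wanted}, found {installed or 'not installed'}")
    if not problems:
        print("All pinned versions are installed.")
```

The version lookup is passed in as a parameter so the comparison logic can be exercised without touching the real environment.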
Then in your code import like so:
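from keras.models import *
from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional, Input
# import CRF alongside standalone keras (not tensorflow.keras) so every layer comes from the same Keras
from keras_contrib.layers import CRF
# etc.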
Hope this helps, good luck!
Solution 2:[2]
If you are using TensorFlow 2, you can try TensorFlow Addons. If you are using tensorflow==1.15.0, you can try tf-crf-layer.
Solution 3:[3]
The keras-contrib README explains how to install it for tf.keras:
git clone https://www.github.com/keras-team/keras-contrib.git
cd keras-contrib
python convert_to_tf_keras.py
USE_TF_KERAS=1 python setup.py install
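Note that USE_TF_KERAS=1 python setup.py install is POSIX-shell syntax; in the Windows cmd shell (the traceback in the question comes from a Windows machine) you would run set USE_TF_KERAS=1 first and then python setup.py install. As a shell-independent alternative, here is a minimal sketch that runs the same two steps from Python. It assumes keras-contrib has already been cloned into repo_dir, and install_keras_contrib_for_tf_keras is a hypothetical helper name.

```python
# Minimal sketch: run the two README steps from Python so USE_TF_KERAS is set
# the same way on Windows and POSIX. Assumes the repo was already cloned into repo_dir.
import os
import subprocess
import sys


def install_keras_contrib_for_tf_keras(repo_dir, run=subprocess.run):
    """Convert keras-contrib to tf.keras imports, then install it with USE_TF_KERAS=1."""
    # Step 1: python convert_to_tf_keras.py
    run([sys.executable, "convert_to_tf_keras.py"], cwd=repo_dir, check=True)
    # Step 2: USE_TF_KERAS=1 python setup.py install
    # Copy the current environment and add the flag only for this child process.
    env = dict(os.environ, USE_TF_KERAS="1")
    run([sys.executable, "setup.py", "install"], cwd=repo_dir, env=env, check=True)
```

After the git clone you would call install_keras_contrib_for_tf_keras("keras-contrib"). Using sys.executable makes both steps run with the same interpreter (and therefore the same environment) as the script itself, which avoids installing into a different Python by accident.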
Solution 4:[4]
I went through the possible solutions; these are the ones that worked for me:
- Install tf2crf (https://pypi.org/project/tf2crf/): it provides a simple CRF layer for TensorFlow 2 Keras.
- Use TensorFlow SIG Addons (https://www.tensorflow.org/addons/api_docs/python/tfa/layers/CRF): it provides functionality that is not available in core TensorFlow, including a CRF layer.
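Whichever package you pick, a linear-chain CRF layer does the same job at prediction time: given per-token tag scores (for example the BERT output in the question) and a learned tag-to-tag transition matrix, it returns the single highest-scoring tag sequence using Viterbi decoding. To make that decoded output concrete, here is a minimal pure-Python sketch of the decoding step. It is an illustration under simplifying assumptions (one unpadded sentence, no start/end transitions), not the code of tf2crf, TensorFlow Addons or keras_contrib, and viterbi_decode is a hypothetical name.

```python
# Minimal, framework-free sketch of linear-chain CRF (Viterbi) decoding.
# Illustration only: not the implementation used by tf2crf or TensorFlow Addons.
def viterbi_decode(emissions, transitions):
    """Return (best_tag_path, best_score).

    emissions[t][j]   -- score of tag j at time step t (e.g. per-token logits)
    transitions[i][j] -- score of moving from tag i to tag j
    """
    num_tags = len(transitions)
    # score[j]: best score of any path ending in tag j at the current step
    score = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_score, pointers = [], []
        for j in range(num_tags):
            # Best previous tag i for reaching tag j at step t
            best_i = max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            new_score.append(score[best_i] + transitions[best_i][j] + emissions[t][j])
            pointers.append(best_i)
        score = new_score
        backpointers.append(pointers)
    # Start from the best final tag and follow the back-pointers to the first step
    last = max(range(num_tags), key=lambda j: score[j])
    path = [last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path, score[last]
```

With all-zero transitions this reduces to picking the best tag at each step independently; a strongly negative transition score (say, from an O tag straight to an I- tag) is what lets the CRF overrule per-token scores and keep the predicted tag sequence consistent.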
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source | 
|---|---|
| Solution 1 | brogeo | 
| Solution 2 | it doesn't matter | 
| Solution 3 | Shahidur | 
| Solution 4 | impyadav | 


