How to pass custom weights to scikit-learn wrappers (e.g. KerasClassifier) in a multilabel classification problem
I'm building a classifier chain for a multilabel problem, using a Keras binary classifier as each link in the chain. I have 17 labels as classification targets, and the dataset is imbalanced across these labels. I want to train the chain with custom class weights. Normally, when I'm not using the scikit-learn wrappers, I pass the custom weights to the fit function.
This is the code to generate weights for these classes:
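import numpy as np
from sklearn.utils import class_weight
# Collapse each one-hot row to a single class index (keeps only the first positive label)
y_ints = [y.argmax() for y in y_train]
# classes and y are keyword-only arguments in recent scikit-learn versions
class_weights = class_weight.compute_class_weight(custom_weight_dict,
                                                  classes=np.unique(y_ints),
                                                  y=y_ints)
# Keras fit() expects a {class_index: weight} dict, not an array
class_weights = dict(zip(np.unique(y_ints), class_weights))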
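Because `y.argmax()` returns the index of the first positive entry, the weights above describe a single-label view of multilabel rows. Every link in a classifier chain is a binary classifier for one label, so one alternative is a separate `{0: w_neg, 1: w_pos}` dict per label. The sketch below assumes `y_train` is a binary indicator matrix of shape (n_samples, 17) and applies the same n_samples / (n_classes * class_count) formula as scikit-learn's `'balanced'` mode; the helper name `per_label_class_weights` is hypothetical.

```python
import numpy as np

def per_label_class_weights(Y):
    """Return one {0: w_neg, 1: w_pos} dict per column of a binary label matrix.

    Uses the 'balanced' heuristic n_samples / (n_classes * class_count),
    with n_classes = 2 because every chain link is a binary classifier.
    """
    Y = np.asarray(Y).astype(int)
    n_samples, n_labels = Y.shape
    weights = []
    for j in range(n_labels):
        # counts[0] = negatives, counts[1] = positives for label j
        counts = np.bincount(Y[:, j], minlength=2)
        # Fall back to 1.0 if a class never occurs for this label
        weights.append({c: float(n_samples / (2 * counts[c])) if counts[c] else 1.0
                        for c in (0, 1)})
    return weights

# Toy indicator matrix: 4 samples, 2 labels, one positive per label
Y_toy = np.array([[1, 0], [0, 0], [0, 1], [0, 0]])
print(per_label_class_weights(Y_toy))
```

When both classes occur in a column, each dict should match `compute_class_weight('balanced', classes=np.array([0, 1]), y=Y[:, j])` for that column.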
And here is my model: a Keras binary classifier, wrapped in KerasClassifier and used as the base estimator of a chain of binary classifiers.
# Imports for the legacy tf.keras scikit-learn wrapper (removed from recent TensorFlow releases)
from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.multioutput import ClassifierChain

# length_long_sentence, vocab_size, multi_label_order, X_train and y_train are defined elsewhere
def create_model():
    input_size = length_long_sentence
    embedding_size = 128
    lstm_size = 64
    #----------------------------Model--------------------------------
    current_input = Input(shape=(input_size,))
    emb_current = Embedding(vocab_size, embedding_size, input_length=input_size)(current_input)
    out_current = Bidirectional(LSTM(units=lstm_size))(emb_current)
    # One sigmoid unit: each link of the chain predicts a single binary label
    output = Dense(units=1, activation='sigmoid')(out_current)
    model = Model(inputs=current_input, outputs=output)
    #-------------------------------compile-------------
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

model = KerasClassifier(build_fn=create_model, epochs=1, batch_size=256, shuffle=True, verbose=1, validation_split=0.2)
chain = ClassifierChain(model, order=multi_label_order, random_state=42)
chain.fit(X_train, y_train)  # returns the fitted chain itself, not a Keras History
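ClassifierChain clones one base estimator for every label, so any class weights attached to that estimator are shared by all links. If each label needs its own weights, one option is to build the chain by hand. The sketch below mimics the chain idea (during training each link sees X plus the true values of the labels earlier in the order; at prediction time it sees the earlier links' predictions) and gives each link its own weight dict. It uses scikit-learn's `LogisticRegression` as a stand-in because it accepts `class_weight` as a parameter; `fit_weighted_chain` and `predict_chain` are hypothetical helpers. With the Keras model, the assumed analogue would be calling `create_model()` per link and passing that link's dict to `fit(..., class_weight=...)`.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression

def fit_weighted_chain(base, X, Y, order, weights):
    """Fit one clone of `base` per label in `order`, each with its own class_weight."""
    order = np.asarray(order, dtype=int)
    links = []
    for pos, j in enumerate(order):
        # Training input for this link: X plus the true values of earlier labels
        X_aug = np.hstack([X, Y[:, order[:pos]]])
        est = clone(base).set_params(class_weight=weights[j])
        links.append(est.fit(X_aug, Y[:, j]))
    return links

def predict_chain(links, X, order):
    """Predict label by label, feeding earlier predictions into later links."""
    order = np.asarray(order, dtype=int)
    Y_pred = np.zeros((X.shape[0], len(order)), dtype=int)
    for pos, (j, est) in enumerate(zip(order, links)):
        X_aug = np.hstack([X, Y_pred[:, order[:pos]]])
        Y_pred[:, j] = est.predict(X_aug)
    return Y_pred

# Toy data: two binary labels derived from the features
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
Y = (X[:, :2] > 0).astype(int)
order = [1, 0]
weights = [{0: 1.0, 1: 3.0}, {0: 1.0, 1: 1.5}]  # hypothetical per-label weights
links = fit_weighted_chain(LogisticRegression(), X, Y, order, weights)
print(predict_chain(links, X, order)[:5])
```

The trade-off is that the hand-built chain gives up scikit-learn conveniences such as `order='random'` and the `cv` option of ClassifierChain.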
The fit method of ClassifierChain only takes the training features and training labels as input. Is there any way to pass my class weights so that they are used during training and the accuracy on the rare classes improves?
Solution 1:[1]
The Keras Model.fit method accepts a class_weight argument, analogous to the class_weight option of many scikit-learn classifiers.
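In the question's setup, the weights therefore have to reach fit through the wrapper rather than through chain.fit. With the legacy `keras.wrappers.scikit_learn.KerasClassifier`, constructor keyword arguments that match fit parameters are forwarded to fit (that is how `epochs`, `batch_size` and `validation_split` already reach it in the question), so something like `KerasClassifier(build_fn=create_model, class_weight={0: 1.0, 1: 5.0}, ...)` should work; treat this as an assumption to verify against your installed version. Since ClassifierChain clones the base estimator for each label, that one dict is then used by every link. The sketch below shows the same propagation with scikit-learn's `LogisticRegression` standing in for the Keras wrapper; the toy data and weights are made up.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain

# Toy multilabel data: three binary labels derived from the features
rng = np.random.default_rng(42)
X = rng.normal(size=(300, 4))
Y = (X[:, :3] > 0.5).astype(int)  # positives are the minority for each label

# Stand-in for KerasClassifier(..., class_weight=...): the weight dict belongs to
# the base estimator, not to chain.fit()
base = LogisticRegression(class_weight={0: 1.0, 1: 3.0})
chain = ClassifierChain(base, order=[2, 0, 1], random_state=42)
chain.fit(X, Y)

# Each fitted link is a clone of `base`, so all of them carry the same weights
for link in chain.estimators_:
    print(link.class_weight)
```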
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | RomanS |