How do I retrieve the one-hot-encoded feature names in tensorflow.keras preprocessing layers?
What is the tf.keras equivalent of encoder.get_feature_names found in sklearn? As shown in this SO question, I need this to get all the one-hot-encoded feature names created with "from tensorflow.keras.layers.experimental import preprocessing". I would appreciate any answer with an example.
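For reference, this is the scikit-learn behaviour I mean (a minimal sketch; the column name and categories are just illustrative, and newer scikit-learn versions rename the method to get_feature_names_out):

from sklearn.preprocessing import OneHotEncoder

enc = OneHotEncoder()
enc.fit([["INLAND"], ["NEAR BAY"], ["ISLAND"]])
print(enc.get_feature_names(["ocean_proximity"]))
# ['ocean_proximity_INLAND' 'ocean_proximity_ISLAND' 'ocean_proximity_NEAR BAY']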
Solution 1:
Here is a way to do this. A few relevant notes:
I'm using the California housing dataset from scikit-learn. I haven't included that portion of the code, but train_ds is already split into X and y, and the dataset is already set up.
I set up this particular feature_layer using feature_columns, which I know are deprecated in favor of preprocessing layers in Keras. It was instantiated with:
feature_layer = tf.keras.layers.DenseFeatures(
    dictionary_of_all_feature_columns.values())
I don't believe this affects the use of the function below in other contexts, but I'll verify that as I get to the preprocessing layers in Keras. The function also only handles numeric, one-hot-encoded (indicator), and bucketized feature columns; I'm at the beginning of my study of feature engineering, so I haven't run into much diversity in variable types yet. I'll update this answer as I make progress. A hedged sketch of how such a feature_layer might be assembled is shown just below.
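Since the notes above assume a feature_layer built from feature columns, here is a minimal sketch of how dictionary_of_all_feature_columns could be assembled for the California housing columns. The bucket boundaries and the vocabulary are assumptions inferred from the post-processed output further down, and the normalization visible in the numeric columns there is omitted:

import tensorflow as tf

# Plain numeric feature columns for the continuous variables.
numeric_names = ["longitude", "latitude", "total_rooms", "total_bedrooms",
                 "population", "households", "median_income"]
dictionary_of_all_feature_columns = {
    name: tf.feature_column.numeric_column(name) for name in numeric_names
}

# Bucketized column built on top of a numeric source column
# (boundaries are assumptions read off the output below).
age = tf.feature_column.numeric_column("housing_median_age")
dictionary_of_all_feature_columns["housing_median_age"] = (
    tf.feature_column.bucketized_column(
        age, boundaries=[2, 5, 10, 15, 20, 30, 40]))

# Indicator (one-hot) column over a categorical vocabulary.
ocean = tf.feature_column.categorical_column_with_vocabulary_list(
    "ocean_proximity",
    ["<1H OCEAN", "INLAND", "NEAR OCEAN", "NEAR BAY", "ISLAND"])
dictionary_of_all_feature_columns["ocean_proximity"] = (
    tf.feature_column.indicator_column(ocean))

feature_layer = tf.keras.layers.DenseFeatures(
    dictionary_of_all_feature_columns.values())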
import numpy as np   # needed for the np.inf bucket boundaries below
import pandas as pd

# Get one batch of what we're going to process
[(X, y)] = train_ds.take(1)


def get_feature_names(feature_layer):
    """
    Takes as input a preprocessing layer of type
    keras.feature_column.dense_features_v2.DenseFeatures
    and returns as output a list containing variable names in the order that
    they are processed by the layer. Only deals with numeric and categorical
    variables. The one-hot encoded variable names match how scikit-learn
    names them, namely variable_name + "_" + category_in_vocabulary_list,
    and similarly with bucket ranges.
    """
    feature_list = []
    feature_layer_dictionary = feature_layer.get_config()
    list_of_encoded_features = feature_layer_dictionary["feature_columns"]
    for encoded_feature in list_of_encoded_features:
        class_name = encoded_feature["class_name"]
        if class_name == 'NumericColumn':
            # Plain numeric column: the name is just the key.
            feature_list.append(encoded_feature['config']['key'])
        elif class_name == 'IndicatorColumn':
            # One-hot column: one name per vocabulary entry.
            variable_name = encoded_feature['config'][
                'categorical_column']['config']['key']
            category_list = list(
                encoded_feature['config']
                ['categorical_column']['config']
                ['vocabulary_list'])
            for category in category_list:
                feature_list.append(variable_name + "_" + category)
        elif class_name == 'BucketizedColumn':
            # Bucketized column: one name per bucket, bounded by -inf and inf.
            variable_name = encoded_feature['config'][
                'source_column']['config']['key']
            boundary_list = list(encoded_feature['config']['boundaries'])
            boundary_list.insert(0, -np.inf)
            boundary_list.append(np.inf)
            for i in range(len(boundary_list) - 1):
                begin_bucket = str(boundary_list[i])
                end_bucket = str(boundary_list[i + 1])
                both_sides_bucket = begin_bucket + "_" + end_bucket
                encoded_variable_name = variable_name + "_" + both_sides_bucket
                feature_list.append(encoded_variable_name)
        else:
            pass
    return feature_list


preprocessed_df = pd.DataFrame(X, columns=list(X.keys()))
postprocessed_df = pd.DataFrame(
    feature_layer(X),
    columns=get_feature_names(feature_layer))
In [344]: preprocessed_df.iloc[0]
Out[344]:
longitude -119.3
latitude 36.34
housing_median_age 45.0
total_rooms 3723.0
total_bedrooms 831.0
population 2256.0
households 770.0
median_income 1.8299
ocean_proximity b'INLAND'
Name: 0, dtype: object
In [345]: postprocessed_df.iloc[0]
Out[345]:
households 0.707083
housing_median_age_-inf_2 0.000000
housing_median_age_2_5 0.000000
housing_median_age_5_10 0.000000
housing_median_age_10_15 0.000000
housing_median_age_15_20 0.000000
housing_median_age_20_30 0.000000
housing_median_age_30_40 0.000000
housing_median_age_40_inf 1.000000
latitude 0.334551
longitude 0.132462
median_income -1.066948
ocean_proximity_<1H OCEAN 0.000000
ocean_proximity_INLAND 1.000000
ocean_proximity_NEAR OCEAN 0.000000
ocean_proximity_NEAR BAY 0.000000
ocean_proximity_ISLAND 0.000000
population 0.724704
total_bedrooms 0.697491
total_rooms 0.501156
Name: 0, dtype: float32
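As for the tf.keras preprocessing layers the question actually asks about: I haven't gotten there yet, but the newer lookup layers expose their vocabulary directly, so the same kind of names can be built by hand. A minimal sketch, assuming a recent TF 2.x and the same hypothetical ocean_proximity vocabulary as above (an assumption, not tested against the dataset):

import tensorflow as tf

# Adapt a StringLookup layer to the categorical column (values are illustrative).
lookup = tf.keras.layers.StringLookup()
lookup.adapt(tf.constant(["<1H OCEAN", "INLAND", "NEAR OCEAN", "NEAR BAY", "ISLAND"]))

# A CategoryEncoding layer turns the integer indices into one-hot vectors.
one_hot = tf.keras.layers.CategoryEncoding(
    num_tokens=lookup.vocabulary_size(), output_mode="one_hot")

# get_vocabulary() returns the terms in index order (including the default
# "[UNK]" OOV token at index 0), so the names line up with the one-hot columns.
feature_names = ["ocean_proximity_" + term for term in lookup.get_vocabulary()]
print(feature_names)

Here lookup(...) produces the integer indices, one_hot(...) turns them into the one-hot matrix, and feature_names gives the column labels in matching order, which is the same pattern the get_feature_names function above implements for feature columns.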
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow