Mask R-CNN is not loading weights properly for inference and re-training
QUESTION:
I'm new to computer vision, and this is my second project with it. I am running an edited version of the Matterport Mask R-CNN that runs with tensorflow-gpu==2.7.0 (I found out later that it would have worked just fine with an older version). I am trying to use it with a pen dataset I created.
The problem is that whenever I load the trained weights into the model to resume training, the metrics all skyrocket back up. I also get bad predictions when I load the weights for inference. Why are my weights not loading or saving properly? I save the weights using callbacks and load them with the following:
model = modellib.MaskRCNN(mode="inference",
                          config=inference_config,
                          model_dir=MODEL_DIR)

# Get path to saved weights
model_path = model.find_last()

# Load trained weights
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)
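One thing worth checking with `by_name=True`: Keras matches saved weights to layers by layer name, so a layer whose name differs between the checkpoint and the model keeps its initial weights even though the load itself does not fail. The sketch below is plain Python and not part of Mask R-CNN; `compare_layer_names` is a hypothetical helper, and the layer names in the usage lines are made up for illustration. In practice the two lists would come from the saved HDF5 file and from `model.keras_model.layers`.

```python
def compare_layer_names(checkpoint_names, model_names):
    """Report which layers a name-based weight load would and would not touch."""
    checkpoint_set = set(checkpoint_names)
    model_set = set(model_names)
    return {
        # Present on both sides: these layers receive saved weights.
        "loaded": sorted(checkpoint_set & model_set),
        # In the model only: these layers keep their initial weights.
        "missing_in_checkpoint": sorted(model_set - checkpoint_set),
        # In the checkpoint only: these saved weights are never used.
        "unused_in_checkpoint": sorted(checkpoint_set - model_set),
    }


# Made-up layer names, for illustration only.
report = compare_layer_names(
    checkpoint_names=["conv1", "mrcnn_mask", "mrcnn_class_logits"],
    model_names=["conv1", "mrcnn_mask", "mrcnn_class_logits_1"],
)
print(report["missing_in_checkpoint"])
```

A non-empty `missing_in_checkpoint` list would be one possible explanation for a model that behaves as if it were untrained after loading.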
WHAT I'VE TRIED:
I have tried saving the whole model by changing save_weights_only in the callbacks to False. I ran into the get_config() issue in this thread and followed through on some of those solutions, but to no avail.
I have also tried changing the image sizes and the number of epochs.
I have tried saving the model using:
from tensorflow import keras

# Save the full Keras model, then load it back
model.keras_model.save('path/to/location')  # complete file path
model = keras.models.load_model('path/to/location')
which led to the same get_config() issue.
RESOURCES:
Here is a list of the things I am running:
# ITEM ########### VERSION ##########################
# Python # 3.9.7 #
# conda # 4.10.3 #
# CUDA # 11.4 #
# WindowsOS # 11 #
# cuDNN # 8.2.4 #
#####################################################
################################### PACKAGES ##################################
# packages in environment at C:\Users\ecsan\anaconda3\envs\Prototype:
# Command: conda list
# Name #################### Version ################ Build # Channel ############
# absl-py 1.0.0 pypi_0 pypi #
# alabaster 0.7.12 pypi_0 pypi #
# argon2-cffi 21.1.0 pypi_0 pypi #
# astunparse 1.6.3 pypi_0 pypi #
# attrs 21.2.0 pypi_0 pypi #
# babel 2.9.1 pypi_0 pypi #
# backcall 0.2.0 pypi_0 pypi #
# bleach 4.1.0 pypi_0 pypi #
# ca-certificates 2021.10.8 h5b45459_0 conda-forge #
# cachetools 4.2.4 pypi_0 pypi #
# certifi 2021.10.8 pypi_0 pypi #
# cffi 1.15.0 pypi_0 pypi #
# charset-normalizer 2.0.9 pypi_0 pypi #
# colorama 0.4.4 pypi_0 pypi #
# console_shortcut 0.1.1 4 #
# cycler 0.11.0 pypi_0 pypi #
# cython 0.29.25 pypi_0 pypi #
# debugpy 1.5.1 pypi_0 pypi #
# decorator 5.1.0 pypi_0 pypi #
# defusedxml 0.7.1 pypi_0 pypi #
# dill 0.3.4 pypi_0 pypi #
# docutils 0.17.1 pypi_0 pypi #
# entrypoints 0.3 pypi_0 pypi #
# flatbuffers 2.0 pypi_0 pypi #
# fonttools 4.28.3 pypi_0 pypi #
# gast 0.4.0 pypi_0 pypi #
# google-auth 2.3.3 pypi_0 pypi #
# google-auth-oauthlib 0.4.6 pypi_0 pypi #
# google-pasta 0.2.0 pypi_0 pypi #
# grpcio 1.42.0 pypi_0 pypi #
# h5py 3.6.0 pypi_0 pypi #
# idna 3.3 pypi_0 pypi #
# imageio 2.13.2 pypi_0 pypi #
# imagesize 1.3.0 pypi_0 pypi #
# imgaug 0.4.0 pypi_0 pypi #
# importlib-metadata 4.8.2 pypi_0 pypi #
# ipykernel 6.6.0 pypi_0 pypi #
# ipyparallel 8.0.0 pypi_0 pypi #
# ipython 7.30.1 pypi_0 pypi #
# ipython-genutils 0.2.0 pypi_0 pypi #
# ipywidgets 7.6.5 pypi_0 pypi #
# jedi 0.18.1 pypi_0 pypi #
# jinja2 3.0.3 pypi_0 pypi #
# joblib 1.1.0 pypi_0 pypi #
# jsonschema 4.2.1 pypi_0 pypi #
# jupyter-client 7.1.0 pypi_0 pypi #
# jupyter-core 4.9.1 pypi_0 pypi #
# jupyterlab-pygments 0.1.2 pypi_0 pypi #
# jupyterlab-widgets 1.0.2 pypi_0 pypi #
# keras 2.7.0 pypi_0 pypi #
# keras-preprocessing 1.1.2 pypi_0 pypi #
# kiwisolver 1.3.2 pypi_0 pypi #
# libclang 12.0.0 pypi_0 pypi #
# markdown 3.3.6 pypi_0 pypi #
# markupsafe 2.0.1 pypi_0 pypi #
# matplotlib 3.5.0 pypi_0 pypi #
# matplotlib-inline 0.1.3 pypi_0 pypi #
# mistune 0.8.4 pypi_0 pypi #
# nbclient 0.5.9 pypi_0 pypi #
# nbconvert 6.3.0 pypi_0 pypi #
# nbformat 5.1.3 pypi_0 pypi #
# nest-asyncio 1.5.4 pypi_0 #
# networkx 2.6.3 pypi_0 pypi #
# nose 1.3.7 pypi_0 pypi #
# notebook 6.4.6 pypi_0 pypi #
# numpy 1.19.5 pypi_0 pypi #
# oauthlib 3.1.1 pypi_0 pypi #
# opencv-python 4.5.4.60 pypi_0 pypi #
# openssl 3.0.0 h8ffe710_2 conda-forge #
# opt-einsum 3.3.0 pypi_0 pypi #
# packaging 21.3 pypi_0 pypi #
# pandocfilters 1.5.0 pypi_0 pypi #
# parso 0.8.3 pypi_0 pypi #
# pickleshare 0.7.5 pypi_0 pypi #
# pillow 8.4.0 pypi_0 pypi #
# pip 21.3.1 pyhd8ed1ab_0 conda-forge #
# prometheus-client 0.12.0 pypi_0 pypi #
# prompt-toolkit 3.0.23 pypi_0 pypi #
# protobuf 3.19.1 pypi_0 pypi #
# psutil 5.8.0 pypi_0 pypi #
# pyasn1 0.4.8 pypi_0 pypi #
# pyasn1-modules 0.2.8 pypi_0 pypi #
# pycparser 2.21 pypi_0 pypi #
# pygments 2.10.0 pypi_0 pypi #
# pyparsing 3.0.6 pypi_0 pypi #
# pyrsistent 0.18.0 pypi_0 pypi #
# python 3.9.7 h900ac77_3_cpython conda-forge #
# python-dateutil 2.8.2 pypi_0 pypi #
# python_abi 3.9 2_cp39 conda-forge #
# pytz 2021.3 pypi_0 pypi #
# pywavelets 1.2.0 pypi_0 pypi #
# pywin32 302 pypi_0 pypi #
# pywinpty 1.1.6 pypi_0 pypi #
# pyzmq 22.3.0 pypi_0 pypi #
# qtconsole 5.2.1 pypi_0 pypi #
# qtpy 1.11.3 pypi_0 pypi #
# requests 2.26.0 pypi_0 pypi #
# requests-oauthlib 1.3.0 pypi_0 pypi #
# rsa 4.8 pypi_0 pypi #
# scikit-image 0.18.3 pypi_0 pypi #
# scipy 1.7.3 pypi_0 pypi #
# send2trash 1.8.0 pypi_0 pypi #
# setuptools 59.4.0 py39hcbf5309_0 conda-forge #
# setuptools-scm 6.3.2 pypi_0 pypi #
# shapely 1.8.0 pypi_0 pypi #
# six 1.15.0 pypi_0 pypi #
# snowballstemmer 2.2.0 pypi_0 pypi #
# sphinx 4.3.1 pypi_0 pypi #
# sphinxcontrib-applehelp 1.0.2 pypi_0 pypi #
# sphinxcontrib-devhelp 1.0.2 pypi_0 pypi #
# sphinxcontrib-htmlhelp 2.0.0 pypi_0 pypi #
# sphinxcontrib-jsmath 1.0.1 pypi_0 pypi #
# sphinxcontrib-qthelp 1.0.3 pypi_0 pypi #
# sphinxcontrib-serializinghtml 1.1.5 pypi_0 pypi #
# sqlite 3.37.0 h8ffe710_0 conda-forge #
# tb-nightly 2.8.0a20211220 pypi_0 pypi #
# tensorboard 2.7.0 pypi_0 pypi #
# tensorboard-data-server 0.6.1 pypi_0 pypi #
# tensorboard-plugin-wit 1.8.0 pypi_0 pypi #
# tensorflow-estimator 2.7.0 pypi_0 pypi #
# tensorflow-gpu 2.7.0 pypi_0 pypi #
# tensorflow-io-gcs-filesystem 0.23.1 pypi_0 pypi #
# termcolor 1.1.0 pypi_0 pypi #
# terminado 0.12.1 pypi_0 pypi #
# testpath 0.5.0 pypi_0 pypi #
# tf-estimator-nightly 2.8.0.dev2021122009 pypi_0 pypi #
# tifffile 2021.11.2 pypi_0 pypi #
# tomli 1.2.2 pypi_0 pypi #
# tornado 6.1 pypi_0 pypi #
# tqdm 4.62.3 pypi_0 pypi #
# traitlets 5.1.1 pypi_0 pypi #
# typing-extensions 4.0.1 pypi_0 pypi #
# tzdata 2021e he74cb21_0 conda-forge #
# ucrt 10.0.20348.0 h57928b3_0 conda-forge #
# urllib3 1.26.7 pypi_0 pypi #
# vc 14.2 hb210afc_5 conda-forge #
# vs2015_runtime 14.29.30037 h902a5da_5 conda-forge #
# wcwidth 0.2.5 pypi_0 pypi #
# webencodings 0.5.1 pypi_0 pypi #
# werkzeug 2.0.2 pypi_0 pypi #
# wheel 0.37.0 pyhd8ed1ab_1 conda-forge #
# widgetsnbextension 3.5.2 pypi_0 pypi #
# wrapt 1.13.3 pypi_0 pypi #
# zipp 3.6.0 pypi_0 pypi #
###############################################################################
Here is a link to my TensorBoard and an example of a bad prediction.
You should see the model learning and then a spike at the end; that spike is where I loaded the weights again and resumed training.
https://tensorboard.dev/experiment/KkgugOP7RGu12lVCA6M29Q/
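When resuming, it can help to confirm which checkpoint file is actually being picked up and which epoch it encodes. The sketch below assumes checkpoint files are named like `mask_rcnn_pen_0012.h5` (a name followed by a four-digit epoch). That naming pattern, the example paths, and the helpers `epoch_from_checkpoint` and `latest_checkpoint` are all assumptions for illustration and are not taken from the post.

```python
import re
from pathlib import PurePosixPath

# Assumed naming scheme: mask_rcnn_<name>_<4-digit epoch>.h5
CHECKPOINT_PATTERN = re.compile(r"^mask_rcnn_[\w-]+_(\d{4})\.h5$")


def epoch_from_checkpoint(path):
    """Return the epoch number encoded in a checkpoint file name, or None."""
    match = CHECKPOINT_PATTERN.match(PurePosixPath(path).name)
    return int(match.group(1)) if match else None


def latest_checkpoint(paths):
    """Return the path with the highest encoded epoch, or None if none match."""
    numbered = [(epoch_from_checkpoint(p), p) for p in paths]
    numbered = [(epoch, p) for epoch, p in numbered if epoch is not None]
    return max(numbered)[1] if numbered else None


# Hypothetical file list, using forward slashes.
files = [
    "logs/pen/mask_rcnn_pen_0003.h5",
    "logs/pen/mask_rcnn_pen_0012.h5",
    "logs/pen/notes.txt",
]
print(latest_checkpoint(files), epoch_from_checkpoint(latest_checkpoint(files)))
```

Comparing the printed path with the one reported by `print("Loading weights from ", model_path)` is a cheap way to rule out resuming from an older checkpoint than intended.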
Here is my custom config for training:
class CustomConfig(Config):
    """Configuration for training on the pen dataset.

    Derives from the base Config class and overrides some values.
    """
    # Skip detections with < 70% confidence
    DETECTION_MIN_CONFIDENCE = 0.7

    # Give the configuration a recognizable name
    NAME = "PEN"

    # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
    # GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8

    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # background + PEN

    # Use small images for faster training. Set the limits of the small side
    # and the large side, and that determines the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 300

    # Use small validation steps since the epoch is small
    VALIDATION_STEPS = 10


config = CustomConfig()
config.display()
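As a quick sanity check on the numbers in this config, the effective batch size and the number of images seen per epoch follow from simple arithmetic. The sketch below uses plain Python rather than the Mask R-CNN `Config` class. The `summarize` helper is hypothetical, and the formula (batch size = GPUs * images per GPU) is the one stated in the config comment above.

```python
def summarize(gpu_count, images_per_gpu, steps_per_epoch, image_max_dim, anchor_scales):
    """Derive a few quantities implied by the training settings."""
    batch_size = gpu_count * images_per_gpu
    return {
        "batch_size": batch_size,
        "images_per_epoch": batch_size * steps_per_epoch,
        # Anchor scales that are as large as the whole image side.
        "anchors_not_smaller_than_image": [
            scale for scale in anchor_scales if scale >= image_max_dim
        ],
    }


# Values copied from the training config in the post.
summary = summarize(
    gpu_count=1,
    images_per_gpu=8,
    steps_per_epoch=300,
    image_max_dim=128,
    anchor_scales=(8, 16, 32, 64, 128),
)
print(summary)
```

With these values each epoch covers 8 * 300 = 2400 images, and the largest anchor scale (128) equals the full image side, which may or may not be intended for small objects such as pens.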
Here is my inference config:
class InferenceConfig(CustomConfig):
    NAME = "PEN"
    NUM_CLASSES = 1 + 1  # background + PEN

    # Use small images for faster training. Set the limits of the small side
    # and the large side, and that determines the image shape.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9
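Because the inference config inherits from the training config, only the overridden values should differ between the two. A small diff makes that easy to verify. The sketch below uses stand-in classes (`TrainSettings`, `InferenceSettings`) that copy the values from the post instead of importing the real `Config` class, and `diff_configs` is a hypothetical helper.

```python
class TrainSettings:
    # Stand-in for CustomConfig, with values copied from the post.
    NAME = "PEN"
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8
    NUM_CLASSES = 1 + 1
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)
    DETECTION_MIN_CONFIDENCE = 0.7


class InferenceSettings(TrainSettings):
    # Stand-in for InferenceConfig: only the values that change.
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.9


def config_values(cls):
    """Collect the upper-case attributes, which hold the settings."""
    return {name: getattr(cls, name) for name in dir(cls) if name.isupper()}


def diff_configs(first, second):
    """Return {setting: (first_value, second_value)} for settings that differ."""
    a, b = config_values(first), config_values(second)
    return {
        key: (a.get(key), b.get(key))
        for key in sorted(set(a) | set(b))
        if a.get(key) != b.get(key)
    }


differences = diff_configs(TrainSettings, InferenceSettings)
print(differences)
```

Any setting that shows up here unexpectedly (for example a different image size or class count) would change the model graph between training and inference.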
If you need additional information please let me know. This is also my first post and any guidance is appreciated.
Solution 1:[1]
I found the "bug": it was the annotations. I needed to use a deprecated version of the VIA annotation tool. After that, it worked much more easily.
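The answer does not say what exactly was wrong with the annotations. One known difference between VIA releases is that 1.x JSON exports store `regions` as a dict keyed by index, while 2.x exports store a list, so a loader written for one format can misread the other. The sketch below shows one way to accept both. `polygons_from_via` is a hypothetical helper and the two sample annotations are made up.

```python
def polygons_from_via(annotation):
    """Return the shape_attributes of every region, for dict or list exports."""
    regions = annotation["regions"]
    if isinstance(regions, dict):
        # VIA 1.x style: {"0": {...}, "1": {...}}
        regions = regions.values()
    # VIA 2.x style is already a list: [{...}, {...}]
    return [region["shape_attributes"] for region in regions]


# Made-up annotations describing the same polygon in both layouts.
shape = {"name": "polygon", "all_points_x": [1, 2, 3], "all_points_y": [4, 5, 6]}
via_1x = {"filename": "pen.jpg", "regions": {"0": {"shape_attributes": shape}}}
via_2x = {"filename": "pen.jpg", "regions": [{"shape_attributes": shape}]}

print(polygons_from_via(via_1x) == polygons_from_via(via_2x))
```

If a dataset loader handles only one of the two layouts, masks can come out empty or wrong without any error, which would look like a training problem rather than a data problem.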
Solution 2:[2]
Please check the versions of Python, TensorFlow, and Keras that you are using, because it will work on Python 3.6 only.
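A minimal way to print the versions in question, using only the standard library so it runs even when a package is missing. The `installed_versions` helper is hypothetical, and the package names passed in are just the ones relevant to this question.

```python
import sys
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {"python": ".".join(str(part) for part in sys.version_info[:3])}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(installed_versions(["tensorflow", "tensorflow-gpu", "keras", "h5py"]))
```

Running this inside the activated conda environment shows what the interpreter actually imports, which can differ from what `conda list` reports for another environment.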
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Igneus |
Solution 2 | Shivani Srivastava |