
TensorFlow.js - Using pretrained ResNet50 network

I created a neural network in Python with TensorFlow, via the Keras API, that leverages the pretrained ResNet50 network to classify 133 different dog breeds.

I now want to deploy this model so that it can be used through TensorFlow.js, but I'm having difficulty getting ResNet50 to work. I can transfer a NN that I created from scratch to TensorFlow.js without a problem, but transferring one that uses a pretrained network isn't as straightforward.

Here is the Python code that I'm trying to adapt:

import numpy as np
from keras.applications.resnet50 import ResNet50, preprocess_input
ResNet50_model = ResNet50(weights="imagenet") # download ImageNet challenge weights

def extractResNet50(tensor): # tensor shape is (1, 224, 224, 3)
    return ResNet50(weights='imagenet', include_top=False, pooling="avg").predict(preprocess_input(tensor))

def dogBreed(img_path):
    tensor = convertToTensor(img_path) # I can do this in TF.js already with no issue

    resnetTensor = extractResNet50(tensor) # tensor returned in shape (1, 2048)
    resnetTensor = np.reshape(resnetTensor, (1, 1, 1, 1, 2048)) # same result as applying np.expand_dims(resnetTensor, axis=0) three times

    # code below I can convert to TF.js without issue
    prediction = model.predict(resnetTensor[0])

How would I convert everything above, except for lines 1 and 4 of dogBreed(), for use in TensorFlow.js?
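
For the preprocess_input part: the Keras ResNet50 version of preprocess_input uses "caffe"-style preprocessing, which flips the channels from RGB to BGR and subtracts the per-channel ImageNet mean, without scaling to [0, 1]. A minimal NumPy sketch of that step, assuming the input is a float RGB batch of shape (N, 224, 224, 3) with values in [0, 255]; the helper name caffe_preprocess is hypothetical. Since it is only a channel reversal plus a subtraction, it should map directly onto TF.js ops such as tf.reverse on the last axis followed by tf.sub.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras' "caffe" preprocessing mode
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(batch_rgb):
    """Hypothetical helper mimicking keras.applications.resnet50.preprocess_input.

    batch_rgb: float array of shape (N, 224, 224, 3), RGB order, values in [0, 255].
    """
    x = np.asarray(batch_rgb, dtype=np.float32)
    x = x[..., ::-1]  # RGB -> BGR by reversing the channel (last) axis
    return x - IMAGENET_MEAN_BGR  # zero-center each channel; no scaling
```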

ResNet is such a big network that it has not been ported to the browser yet, and I doubt it ever will be. At least, it is not available as of the latest version of TensorFlow.js (version 0.14).
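
To put "big" in numbers: the Keras Applications documentation lists ResNet50 at about 25.6 million parameters (98 MB). A back-of-the-envelope sketch of the weight download, assuming float32 weights (4 bytes each), the parameter count reported by Keras' model summary, and the standard ImageNet top being a single 2048-to-1000 dense layer; the constant and function names are illustrative only.

```python
# Total parameters reported by Keras for ResNet50(weights="imagenet")
RESNET50_PARAMS_WITH_TOP = 25_636_712
# The ImageNet top is one Dense(1000) layer on the 2048-d pooled features
TOP_DENSE_PARAMS = 2048 * 1000 + 1000
# include_top=False (as used in the question) drops that layer
RESNET50_PARAMS_NO_TOP = RESNET50_PARAMS_WITH_TOP - TOP_DENSE_PARAMS


def size_mib(num_params, bytes_per_weight=4):
    """Approximate weight storage in MiB for a given bytes-per-weight."""
    return num_params * bytes_per_weight / 2**20


for name, n in [("with top", RESNET50_PARAMS_WITH_TOP), ("no top", RESNET50_PARAMS_NO_TOP)]:
    print(f"ResNet50 {name}: {n:,} params, ~{size_mib(n):.1f} MiB as float32")
```

Even the headless feature extractor is roughly 90 MiB of float32 weights that every visitor would have to download; bytes_per_weight=2 or 1 gives a rough idea of the savings if the weights were quantized, which newer versions of the TensorFlow.js converter support.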

What you can do, on the other hand, is save your Python Keras model and then import the frozen model in JS for prediction.

Update: you are using ResNet50 as the feature extractor for your model. In that case, the frozen model you save needs to contain both ResNet50 and your own model's topology and weights.
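
One way to get both parts into a single saved model is the Keras functional API: chain ResNet50 (include_top=False, pooling="avg") and the trained classifier into one model. A minimal sketch, assuming TensorFlow 2.x with tf.keras and assuming your trained head (called top_model here, a hypothetical name) takes inputs of shape (1, 1, 2048), as model.predict(resnetTensor[0]) in the question implies. The Reshape layer stands in for the np.expand_dims calls; preprocessing stays outside the model, so it still has to happen in JavaScript before predicting.

```python
import numpy as np
import tensorflow as tf


def build_combined_model(top_model, weights="imagenet"):
    """Chain a ResNet50 feature extractor and a trained classifier head into one model."""
    base = tf.keras.applications.ResNet50(
        weights=weights, include_top=False, pooling="avg", input_shape=(224, 224, 3)
    )
    base.trainable = False  # used for inference only
    # Input is expected to be already preprocessed (BGR, ImageNet mean subtracted)
    inputs = tf.keras.Input(shape=(224, 224, 3))
    features = base(inputs)  # (batch, 2048)
    # Same shape the question gets from expand_dims plus [0]: (batch, 1, 1, 2048)
    features = tf.keras.layers.Reshape((1, 1, 2048))(features)
    outputs = top_model(features)
    return tf.keras.Model(inputs, outputs)
```

After saving the result (for example combined.save("combined.h5")), the tensorflowjs package's converter, tensorflowjs_converter --input_format keras combined.h5 web_model/, produces files that TF.js loads with tf.loadLayersModel (tf.loadModel in 0.x releases such as 0.14). How well a model of this size behaves in a browser is not verified by this sketch.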

1- Instead of having two separate architectures in Python, create a single network using TensorFlow directly rather than Keras. The frozen model will then contain ResNet. This might not work well in the browser because ResNet is quite large (I have not tested it myself).

2- Instead of using ResNet in the browser, consider using coco-ssd or MobileNet, which can run in the browser, as the feature extractor. You can see how to use them in the official repo.
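
Option 2 also implies work on the Python side: the existing 133-class head was trained on 2048-d ResNet50 features, while a MobileNet extractor produces features of a different width, so the head would need retraining. A minimal sketch using Keras' MobileNetV2 as the swapped-in extractor, assuming TensorFlow 2.x and 224x224 inputs; build_mobilenet_classifier is a hypothetical name. If the browser side uses a different MobileNet build (for example one from the official repo mentioned above), the head should be trained on features from that same build.

```python
import numpy as np
import tensorflow as tf

NUM_BREEDS = 133  # number of dog breeds from the question


def build_mobilenet_classifier(weights="imagenet"):
    """MobileNetV2 feature extractor plus a new, untrained 133-way softmax head."""
    base = tf.keras.applications.MobileNetV2(
        weights=weights, include_top=False, pooling="avg", input_shape=(224, 224, 3)
    )
    base.trainable = False  # keep pretrained features fixed; only the head is trained
    # Expects MobileNetV2 preprocessing (pixels scaled to [-1, 1])
    inputs = tf.keras.Input(shape=(224, 224, 3))
    features = base(inputs)  # (batch, 1280) instead of ResNet50's 2048
    outputs = tf.keras.layers.Dense(NUM_BREEDS, activation="softmax")(features)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model, base
```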
