How to use an existing Keras model in TensorFlow.js
I have a Keras model that I converted to TensorFlow.js, but I could not load it in JavaScript. What are the steps for that? Here is the model:
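from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

# vocabulary_size, seq_len, train_inputs and train_targets are assumed to be defined earlier
model = Sequential()
model.add(Embedding(vocabulary_size, seq_len, input_length=seq_len))
model.add(LSTM(256, return_sequences=True))
model.add(LSTM(128))
model.add(Dense(256, activation='relu'))
model.add(Dense(vocabulary_size, activation='softmax'))
# compiling the network
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_inputs, train_targets, epochs=256, verbose=1)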
You can load the converted model from an HTTP endpoint inside a React component like this:
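import * as tf from '@tensorflow/tfjs';
import React, {useState, useEffect} from "react";

const url = {
  model: 'http://localhost:81/tfjs-models/model.json',
};

function ModelLoader() {
  // Hooks must be called inside a function component, not at module level
  const [model, setModel] = useState();

  useEffect(() => {
    async function loadModel(url) {
      try {
        // For a layers model (e.g. converted from Keras)
        const loaded = await tf.loadLayersModel(url.model);
        // For a graph model (e.g. converted from a SavedModel)
        // const loaded = await tf.loadGraphModel(url.model);
        setModel(loaded);
        console.log("Load model success");
      } catch (err) {
        console.error(err);
      }
    }
    // Wait until the TensorFlow.js backend is initialized, then load once on mount
    tf.ready().then(() => {
      loadModel(url);
    });
  }, []);

  return <div>{model ? "Model loaded" : "Loading model..."}</div>;
}

export default ModelLoader;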
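The comments in loadModel leave the choice between tf.loadLayersModel and tf.loadGraphModel to you. A minimal sketch of deciding automatically, assuming the model.json was written by the tensorflowjs converter, which records a "format" field ("layers-model" or "graph-model"). The helper name pickLoader and the fallback that checks modelTopology.node for older files are assumptions, not part of the TensorFlow.js API:

```javascript
// Sketch: choose the TensorFlow.js loader name from a parsed model.json.
// Assumption: the converter writes format "layers-model" for Keras models
// and "graph-model" for graph models; older files may omit the field.
function pickLoader(modelJson) {
  const format = modelJson.format;
  if (format === "graph-model") {
    return "loadGraphModel";
  }
  if (format === "layers-model") {
    return "loadLayersModel";
  }
  // Assumed fallback: a graph model's topology is a GraphDef with a
  // "node" array; anything else is treated as a Keras (layers) topology.
  const topology = modelJson.modelTopology || {};
  if (Array.isArray(topology.node)) {
    return "loadGraphModel";
  }
  return "loadLayersModel";
}
```

Inside loadModel you could then fetch url.model, parse it with response.json(), and call tf[pickLoader(json)](url.model).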
The loaded model is then available through the model state variable of the component.
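Using the model from state still requires input shaped like the training data: the Embedding layer takes seq_len integer word ids per sample. The question does not show how train_inputs was built, so this sketch assumes a hypothetical JSON export of the Keras Tokenizer's word_index (word to id, 0 reserved for padding) and reproduces pad_sequences' default left padding and left truncation. It does not replicate the Tokenizer's punctuation filtering; encodeSeed is a hypothetical helper:

```javascript
// Sketch: build the integer input the Embedding layer expects.
// Assumption: wordIndex is a hypothetical JSON export of the Keras
// Tokenizer's word_index (word -> id, with 0 reserved for padding).
function encodeSeed(text, wordIndex, seqLen) {
  const ids = text
    .toLowerCase()
    .split(/\s+/)
    // Unknown words are dropped, like texts_to_sequences without an oov_token.
    .filter((w) => Object.prototype.hasOwnProperty.call(wordIndex, w))
    .map((w) => wordIndex[w]);
  // Keep only the last seqLen ids (pad_sequences' default truncating="pre").
  const tail = seqLen > 0 ? ids.slice(-seqLen) : [];
  // Left-pad with zeros (pad_sequences' default padding="pre").
  const padding = new Array(seqLen - tail.length).fill(0);
  return padding.concat(tail);
}
```

The result can be wrapped as tf.tensor2d([encodeSeed(seed, wordIndex, seqLen)], [1, seqLen], 'int32') and passed to model.predict.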
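Alternatively, save the model in TensorFlow.js format directly from Python with the tensorflowjs package (pip install tensorflowjs):
import tensorflowjs as tfjs
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

model = Sequential()
model.add(Embedding(vocabulary_size, seq_len, input_length=seq_len))
model.add(LSTM(256, return_sequences=True))
model.add(LSTM(128))
model.add(Dense(256, activation='relu'))
model.add(Dense(vocabulary_size, activation='softmax'))
# compiling the network
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_inputs, train_targets, epochs=256, verbose=1)
# save the model as model.json plus binary weight shard files in ./keras_converted
tfjs.converters.save_keras_model(model, './keras_converted')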
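save_keras_model writes model.json plus one or more binary weight shard files, and tf.loadLayersModel fetches those shards relative to the URL of model.json, so serving only model.json makes loading fail. A small sketch that lists the shard URLs the loader will request, assuming the weightsManifest layout the converter writes (an array of groups, each with a paths list); shardUrls is a hypothetical helper:

```javascript
// Sketch: list the weight-shard URLs referenced by a parsed model.json.
// Assumption: shard paths in weightsManifest are relative to model.json.
function shardUrls(modelJson, modelUrl) {
  const urls = [];
  for (const group of modelJson.weightsManifest || []) {
    for (const path of group.paths) {
      // Resolve each shard path against the location of model.json.
      urls.push(new URL(path, modelUrl).href);
    }
  }
  return urls;
}
```

Requesting each returned URL from the browser is a quick way to confirm that every .bin file is hosted next to model.json.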
Load the model in JavaScript:
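import * as tf from '@tensorflow/tfjs';
// Serve model.json together with its .bin weight files from the same directory.
// Top-level await needs an ES module; otherwise wrap this in an async function.
const model = await tf.loadLayersModel('https://hostname:port/path/to/model.json');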