
Load an Amazon SageMaker NTM model locally for inference

I have trained a neural topic model with the SageMaker NTM algorithm directly on the AWS SageMaker platform. Once training is complete, you can download the model files, which are in MXNet format. Once unpacked, they contain:

  • params
  • symbol.json
  • meta.json
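mx.model.load_checkpoint(prefix, epoch) looks for '<prefix>-symbol.json' and '<prefix>-<epoch as 4 digits>.params', so the unpacked symbol.json and params have to be renamed before the loading code below can find them. A minimal sketch, assuming (not confirmed in the post) that the downloaded model.tar.gz contains a single zip archive named model_algo-1 holding the three files listed above. The function names are illustrative.

```python
import os
import shutil
import tarfile
import zipfile


def checkpoint_paths(prefix, epoch=0):
    """File names that mx.model.load_checkpoint(prefix, epoch) expects."""
    return f"{prefix}-symbol.json", f"{prefix}-{epoch:04d}.params"


def unpack_ntm_artifacts(archive, out_dir, prefix="model_algo-1", epoch=0):
    """Unpack model.tar.gz and rename the files so load_checkpoint can find them.

    Assumption: the tarball holds a zip archive named `prefix` that contains
    'symbol.json', 'params' and 'meta.json'.
    """
    os.makedirs(out_dir, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(out_dir)

    inner = os.path.join(out_dir, prefix)
    if zipfile.is_zipfile(inner):
        with zipfile.ZipFile(inner) as zf:
            zf.extractall(out_dir)

    symbol_name, params_name = checkpoint_paths(prefix, epoch)
    shutil.move(os.path.join(out_dir, "symbol.json"), os.path.join(out_dir, symbol_name))
    shutil.move(os.path.join(out_dir, "params"), os.path.join(out_dir, params_name))
    # meta.json is left untouched; load_checkpoint only reads the symbol and params files.
    return inner  # pass this as the prefix to mx.model.load_checkpoint
```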

Following the MXNet documentation, I load the model with this code:

import mxnet as mx

# VOCAB_SIZE is the vocabulary size the model was trained with
sym, arg_params, aux_params = mx.model.load_checkpoint('model_algo-1', 0)
module_model = mx.mod.Module(symbol=sym, label_names=None, context=mx.cpu())

module_model.bind(
    for_training=False,
    data_shapes=[('data', (1, VOCAB_SIZE))]
)

# allow_missing=True is required here; otherwise set_params raises an error for the missing 'n_epoch' variable
module_model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)

I then try to use the model for inference:

module_model.predict(x)  # where x is a numpy array of shape (1, VOCAB_SIZE)

The code runs, but the result is a single value, whereas I expect a distribution over topics:

[11.060672]
<NDArray 1 @cpu(0)>
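One possible reading of the single value, not confirmed in the post: the graph saved in symbol.json may end in the training objective rather than in a topic layer, so predict returns a scalar. A minimal sketch under that assumption lists the graph's intermediate outputs with get_internals() and binds a Module to one of them. The function names are illustrative, and the layer name must come from the listed outputs; 'mean_output' in the usage note is only a guess based on the 'mean_bias' parameter in the error message further down. Whether the chosen layer already is a topic distribution, or still needs normalising, is not established here.

```python
def layer_outputs(internal_names, argument_names, aux_names=()):
    """Drop graph inputs (data, weights, biases, aux states), keeping layer outputs."""
    inputs = set(argument_names) | set(aux_names)
    return [name for name in internal_names if name not in inputs]


def list_layer_outputs(prefix, epoch=0):
    """Names of every intermediate layer output stored in the checkpoint's graph."""
    import mxnet as mx  # imported lazily so layer_outputs() works without MXNet installed

    sym, _, _ = mx.model.load_checkpoint(prefix, epoch)
    internals = sym.get_internals()
    return layer_outputs(internals.list_outputs(), sym.list_arguments(),
                         sym.list_auxiliary_states())


def build_layer_module(prefix, epoch, output_name, vocab_size):
    """Bind a Module whose output is an intermediate layer instead of the graph head."""
    import mxnet as mx

    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Cut the graph at the chosen layer instead of its final (loss-like) output.
    layer_sym = sym.get_internals()[output_name]
    mod = mx.mod.Module(symbol=layer_sym, data_names=["data"], label_names=None,
                        context=mx.cpu())
    mod.bind(for_training=False, data_shapes=[("data", (1, vocab_size))])
    # allow_extra: parameters of layers after the cut are in the checkpoint but not in
    # the truncated graph. A still-missing input such as n_epoch raises an error here.
    mod.set_params(arg_params, aux_params, allow_missing=False, allow_extra=True)
    return mod
```

Usage, with the hypothetical layer name: inspect list_layer_outputs('model_algo-1'), pick a layer, then call build_layer_module('model_algo-1', 0, 'mean_output', VOCAB_SIZE).predict(x).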

EDIT:

I have also tried loading it with Gluon's SymbolBlock, but still no luck:

import warnings

import mxnet as mx
from mxnet import gluon

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    deserialized_net = gluon.nn.SymbolBlock.imports('model_algo-1-symbol.json', ['data'], 'model_algo-1-0000.params', ctx=mx.cpu())

Error:

AssertionError: Parameter 'n_epoch' is missing in file: model_algo-1-0000.params, which contains parameters: 'logsigma_bias', 'enc_0_bias', 'projection_bias', ..., 'enc_1_weight', 'enc_0_weight', 'mean_bias', 'logsigma_weight'. Please make sure source and target networks have the same prefix.

Any help would be great!
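The assertion says 'n_epoch' is a variable of the graph but not a saved parameter, so SymbolBlock.imports cannot fill it. One possible workaround, sketched here and untested against a real NTM artifact, is to build the SymbolBlock on an intermediate layer that does not depend on 'n_epoch' and load with ignore_extra=True so the parameters of the discarded layers are skipped. The helper names and any layer name you pass are hypothetical; missing_from_file reports which inputs the truncated graph would still need.

```python
def missing_from_file(required_names, saved_keys):
    """Names the network needs that a checkpoint file does not contain.

    Module checkpoints store keys as 'arg:<name>' / 'aux:<name>'; strip those prefixes first.
    """
    saved = {k[4:] if k.startswith(("arg:", "aux:")) else k for k in saved_keys}
    return sorted(set(required_names) - saved)


def load_layer_block(symbol_file, param_file, output_name):
    """SymbolBlock that stops at `output_name` instead of the graph's final output."""
    import mxnet as mx  # imported lazily so missing_from_file() works without MXNet
    from mxnet import gluon

    sym = mx.sym.load(symbol_file)
    layer_sym = sym.get_internals()[output_name]
    net = gluon.nn.SymbolBlock(outputs=layer_sym, inputs=mx.sym.var("data"))
    params = net.collect_params()

    missing = missing_from_file(params.keys(), mx.nd.load(param_file).keys())
    if missing:
        raise ValueError(f"layer {output_name!r} still needs unsaved inputs: {missing}")
    # ignore_extra: skip saved parameters of layers after the cut (e.g. the decoder).
    params.load(param_file, ctx=mx.cpu(), ignore_extra=True)
    return net
```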

Running the model locally is not a use case that SageMaker supports. Instead, the model can be hosted on a SageMaker endpoint for real-time inference, or used to make predictions in batch with a batch transform job.

For more details, see:

  1. https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html
  2. https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html
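For the hosted route, a real-time call through boto3 could look like the sketch below. It assumes an endpoint already deployed from the trained model (the endpoint name is a placeholder), that NTM accepts one CSV row of word counts per document, and that the JSON response has the form {"predictions": [{"topic_weights": [...]}]}; check both formats against the NTM documentation before relying on them. The function names are illustrative.

```python
import json


def to_csv_payload(counts):
    """One document as a CSV row of word counts (length = vocabulary size)."""
    return ",".join(str(int(c)) for c in counts)


def parse_topic_weights(body):
    """Per-document topic weights from a JSON response (response format assumed)."""
    return [p["topic_weights"] for p in json.loads(body)["predictions"]]


def predict_topics(endpoint_name, counts):
    """Send one document to a deployed NTM endpoint and return its topic weights."""
    import boto3  # imported lazily so the helpers above work without boto3

    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="text/csv",
        Accept="application/json",
        Body=to_csv_payload(counts),
    )
    return parse_topic_weights(response["Body"].read())
```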
