
Cannot convert GPT-2 model using TensorFlow.js

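I am trying to load a GPT-2 model in a Node.js project. I believe this can be done with the tfjs library, so I tried to convert the GPT-2 model into a tfjs model. Following the suggestion in this answer, I exported the GPT-2 model as a SavedModel.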

!python3 -m pip install -q git+https://github.com/huggingface/transformers.git
!python3 -m pip install tensorflow tensorflowjs

Then I ran the following code to export the SavedModel (the xx.pb file).

from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
import tensorflowjs
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# add the EOS token as PAD token to avoid warnings
model = TFGPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
model.save("./test_gpt2")
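Before running the converter, it can help to confirm that `model.save("./test_gpt2")` really produced a SavedModel directory. A TF2 SavedModel normally contains a `saved_model.pb` file and a `variables/` folder. The helper `looks_like_saved_model` below is a hypothetical name used only for this sketch. It is a plain filesystem check and does not validate the graph itself. On Colab, `./test_gpt2` resolves to `/content/test_gpt2`, the path used in the converter command that follows.

```python
from pathlib import Path


def looks_like_saved_model(export_dir: str) -> bool:
    """Hypothetical helper: True if export_dir has the usual SavedModel layout."""
    root = Path(export_dir)
    # A SavedModel stores the serialized graph in saved_model.pb and the
    # weights under variables/; this only checks that both are present.
    return (root / "saved_model.pb").is_file() and (root / "variables").is_dir()


if __name__ == "__main__":
    # "./test_gpt2" is the directory passed to model.save() above.
    print(looks_like_saved_model("./test_gpt2"))
```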

Then I ran this command to convert the SavedModel into tfjs-compatible files.

!tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='gpt2' \
    --saved_model_tags=serve \
    /content/test_gpt2 \
    /content/test_gpt2_web_model
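When experimenting with converter options, it can be easier to build the command in a script than to edit a shell line each time. The sketch below uses hypothetical helpers, `build_converter_cmd` and `run_converter`, and only the flags that appear in the command above plus the converter's `--signature_name` option.

It omits `--output_node_names`. As we read the tfjs converter documentation, that flag matters for frozen-graph input rather than `tf_saved_model`, but treat this as an assumption. Dropping it is not expected to fix the `StatefulPartitionedCall` error by itself.

```python
import shlex
import subprocess
from typing import List, Optional


def build_converter_cmd(saved_model_dir: str, output_dir: str,
                        tags: str = "serve",
                        signature_name: Optional[str] = None) -> List[str]:
    """Hypothetical helper: assemble a tensorflowjs_converter call for a SavedModel."""
    cmd = [
        "tensorflowjs_converter",
        "--input_format=tf_saved_model",
        "--saved_model_tags={}".format(tags),
    ]
    if signature_name:
        # Optionally pick a specific serving signature of the SavedModel.
        cmd.append("--signature_name={}".format(signature_name))
    cmd += [saved_model_dir, output_dir]
    return cmd


def run_converter(cmd: List[str]) -> None:
    # check=True raises CalledProcessError if the converter exits non-zero,
    # for example when it reports unsupported ops.
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    cmd = build_converter_cmd("/content/test_gpt2", "/content/test_gpt2_web_model")
    # Print a shell-quoted version to inspect before calling run_converter(cmd).
    print(" ".join(shlex.quote(part) for part in cmd))
```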

This results in the following error:

2020-07-08 16:36:11.455383: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-07-08 16:36:11.459979: I tensorflow/core/platform/profile_utils/cpu_utils.cc:102] CPU Frequency: 2300000000 Hz
2020-07-08 16:36:11.460216: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x2e5b100 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-07-08 16:36:11.460284: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-07-08 16:36:18.337463: I tensorflow/core/grappler/devices.cc:60] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA support)
2020-07-08 16:36:18.337631: I tensorflow/core/grappler/clusters/single_machine.cc:356] Starting new session
2020-07-08 16:36:18.536301: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:797] Optimization results for grappler item: graph_to_optimize
2020-07-08 16:36:18.536373: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:799]   function_optimizer: Graph size after: 163 nodes (0), 175 edges (0), time = 43.871ms.
2020-07-08 16:36:18.536384: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:799]   function_optimizer: Graph size after: 163 nodes (0), 175 edges (0), time = 50.779ms.
2020-07-08 16:36:18.536393: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:797] Optimization results for grappler item: __inference__wrapped_model_24863
2020-07-08 16:36:18.536402: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:799]   function_optimizer: function_optimizer did nothing. time = 0.004ms.
2020-07-08 16:36:18.536411: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:799]   function_optimizer: function_optimizer did nothing. time = 0ms.
Traceback (most recent call last):
  File "/usr/local/bin/tensorflowjs_converter", line 8, in <module>
    sys.exit(pip_main())
  File "/usr/local/lib/python3.6/dist-packages/tensorflowjs/converters/converter.py", line 735, in pip_main
    main([' '.join(sys.argv[1:])])
  File "/usr/local/lib/python3.6/dist-packages/tensorflowjs/converters/converter.py", line 739, in main
    convert(argv[0].split(' '))
  File "/usr/local/lib/python3.6/dist-packages/tensorflowjs/converters/converter.py", line 681, in convert
    control_flow_v2=args.control_flow_v2)
  File "/usr/local/lib/python3.6/dist-packages/tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 494, in convert_tf_saved_model
    weight_shard_size_bytes=weight_shard_size_bytes)
  File "/usr/local/lib/python3.6/dist-packages/tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 143, in optimize_graph
    ', '.join(unsupported))
ValueError: Unsupported Ops in the model before optimization
StatefulPartitionedCall

It says that StatefulPartitionedCall is not supported. Is there a way to work around this?
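If the converter is driven from a script, the rejected op names can be pulled out of the error text. The traceback shows that `optimize_graph` builds the list with `', '.join(unsupported)` after the header "Unsupported Ops in the model before optimization". The sketch below relies on that message format. The function name `unsupported_ops` is hypothetical, and the parsing will break if a different tfjs version words the message differently.

```python
from typing import List

# Header text copied from the ValueError in the log above.
HEADER = "Unsupported Ops in the model before optimization"


def unsupported_ops(message: str) -> List[str]:
    """Hypothetical helper: list op names from the converter's ValueError text."""
    if HEADER not in message:
        return []
    # Everything after the header is the ', '-joined list of op names.
    tail = message.split(HEADER, 1)[1]
    return [op.strip() for op in tail.split(",") if op.strip()]


if __name__ == "__main__":
    msg = "Unsupported Ops in the model before optimization\nStatefulPartitionedCall"
    print(unsupported_ops(msg))  # ['StatefulPartitionedCall']
```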

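According to the project's GitHub page, it looks like this issue has been resolved in the 2.1.0 release of TensorFlow. I suggest upgrading the version of TF you are using and seeing whether it works.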

https://github.com/tensorflow/tfjs/issues/3582#issuecomment-668302536
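The answer above recommends checking the installed version against 2.1.0. It says "TensorFlow", but the linked issue is on the tfjs tracker, so the relevant package may be `tensorflowjs`. The sketch below therefore checks both. Upgrading itself would be `pip install --upgrade tensorflowjs` (or `tensorflow`).

The helpers `parse_version`, `installed_version` and `at_least` are hypothetical names for this sketch. The version parsing is deliberately simple: pre-release suffixes such as `rc1` are ignored. The sketch reads `__version__` instead of using `importlib.metadata`, because the traceback shows Python 3.6, where that module is not available.

```python
import importlib
from typing import Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '2.1.0' into (2, 1, 0); suffixes like 'rc1' are dropped (simplification)."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # treat '2.1' as '2.1.0'
        parts.append(0)
    return tuple(parts)


def installed_version(module_name: str) -> Optional[str]:
    """Return module.__version__, or None if the package cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def at_least(version: str, minimum: str = "2.1.0") -> bool:
    return parse_version(version) >= parse_version(minimum)


if __name__ == "__main__":
    for name in ("tensorflow", "tensorflowjs"):
        found = installed_version(name)
        if found is None:
            print("{}: not installed".format(name))
        else:
            print("{} {}: at least 2.1.0? {}".format(name, found, at_least(found)))
```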


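Statement: the technical posts on this site are published under the CC BY-SA 4.0 license. If you repost this content, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.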

 
Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM