How to convert tensorflow.js model weights to pytorch tensors, and back?
I am using ml5.js, a wrapper around tensorflowjs. I want to train a neural network in the browser, download the weights, process them as tensors in PyTorch, and load them back into the browser's tensorflowjs model. How do I convert between these formats (tfjs <-> pytorch)?
The browser model has a save() function which generates three files: a metadata file specific to ml5.js (json), a topology file describing the model architecture (json), and a binary weights file (bin).
// Browser
model.save()
// HTTP/Download
model_meta.json (needed by ml5.js)
model.json (needed by tfjs)
model.weights.bin (needed by tfjs)
# python backend
import json

with open('model.weights.bin', 'rb') as weights_file:
    weights = weights_file.read()
with open('model.json', 'rb') as model_file:
    model = json.loads(model_file.read())
####
pytorch_tensor = convert2tensor(weights, model) # what's in this function?
####
# Do some processing in pytorch
####
new_weights_bin = convert2bin(pytorch_tensor, model) # and in this?
####
Here is sample JavaScript code to generate and load the 3 files in the browser. To load, select all 3 files at once in the dialog box. If they are correct, a popup will show a sample prediction.
I was able to find a way to convert from tfjs model.weights.bin to numpy ndarrays. It is trivial to convert from numpy arrays to a pytorch state_dict, which is a dictionary of tensors keyed by name.
First, the tfjs representation of the model should be understood. model.json describes the model. In python, it can be read as a dictionary. It has the following keys:
The model architecture is described as another json/dictionary under the key modelTopology.
It also has a json list under the key weightsManifest, with one entry per group of weights. Each entry describes the type/shape/location of every weight wrapped up in the corresponding model.weights.bin file. As an aside, the weights manifest allows for multiple .bin files to store weights.
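To make the manifest layout concrete, here is a minimal sketch that walks weightsManifest and works out where each weight sits in its .bin data. The model_json contents, the weight names, and the weight_offsets helper are made up for illustration. It assumes unquantized float32/int32 weights stored back to back in manifest order.

```python
import json
from math import prod

# Hypothetical model.json contents for a tiny dense network; files written
# by model.save() have the same two top-level keys used here.
model_json = json.loads("""
{
  "modelTopology": {"class_name": "Sequential"},
  "weightsManifest": [
    {
      "paths": ["./model.weights.bin"],
      "weights": [
        {"name": "dense/kernel", "shape": [4, 3], "dtype": "float32"},
        {"name": "dense/bias", "shape": [3], "dtype": "float32"}
      ]
    }
  ]
}
""")

# Assumption: only unquantized 4-byte dtypes are present.
BYTES_PER_ELEMENT = {"float32": 4, "int32": 4}


def weight_offsets(manifest):
    """Return {weight name: (byte offset, byte length)} within each group."""
    layout = {}
    for group in manifest:
        offset = 0  # offsets restart for every manifest group
        for spec in group["weights"]:
            length = prod(spec["shape"]) * BYTES_PER_ELEMENT[spec["dtype"]]
            layout[spec["name"]] = (offset, length)
            offset += length
    return layout


layout = weight_offsets(model_json["weightsManifest"])
for name, (offset, length) in layout.items():
    print(f"{name}: bytes {offset}..{offset + length}")
```

The total of the byte lengths in a group should equal the combined size of the .bin files listed under its paths, which is a quick sanity check before decoding anything.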
Tensorflow.js has a companion python package, tensorflowjs, which comes with utility functions to read and write weights between the tf.js binary and numpy array formats.
Each weight file is read as a "group". A group is a list of dictionaries with keys name and data, which refer to the weight name and the numpy array containing the weights. There are optionally other keys too.
group = [{'name': weight_name, 'data': np.ndarray}, ...] # 1 *.bin file
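As a small illustration of the group format, the sketch below builds one group by hand, turns it into a name-keyed dictionary, and rebuilds the group from it. The weight names and shapes are invented. The point worth noticing is that Python dictionaries keep insertion order, so the rebuilt group lists weights in the same order as the original, which matters if the manifest in model.json is reused unchanged.

```python
import numpy as np

# One hypothetical group, i.e. the contents of a single .bin file.
group = [
    {"name": "dense/kernel", "data": np.zeros((4, 3), dtype=np.float32)},
    {"name": "dense/bias", "data": np.ones(3, dtype=np.float32)},
]

# Group -> {name: ndarray}, the shape a state_dict-style mapping takes.
by_name = {entry["name"]: entry["data"] for entry in group}

# {name: ndarray} -> group again; dict order preserves the weight order.
rebuilt = [{"name": name, "data": data} for name, data in by_name.items()]

print([entry["name"] for entry in rebuilt])
```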
Install tensorflowjs. Unfortunately, it will also install tensorflow.
pip install tensorflowjs
Use these functions. Note that I changed the signatures for convenience.
from typing import Dict

import torch
from tensorflowjs.read_weights import decode_weights
from tensorflowjs.write_weights import write_weights


def convert2tensor(weights: bytes, model: Dict) -> Dict[str, torch.Tensor]:
    manifest = model['weightsManifest']
    # If flatten=False, returns a list of groups equal to the number of .bin files.
    # Use flatten=True to convert to a single group
    group = decode_weights(manifest, weights, flatten=True)
    # Convert dicts in tfjs group format into pytorch's state_dict format:
    # {name: str, data: ndarray} -> {name: tensor}
    # .copy() gives each tensor its own writable memory, since arrays decoded
    # from a bytes buffer can be read-only
    state_dict = {d['name']: torch.from_numpy(d['data'].copy()) for d in group}
    return state_dict
def convert2bin(state_dict: Dict[str, torch.Tensor], model: Dict, directory='./'):
    # convert state_dict to groups (list of 1 group); write_weights expects
    # numpy arrays, so tensors are converted first
    groups = [[{'name': key, 'data': value.detach().cpu().numpy()}
               for key, value in state_dict.items()]]
    # this library function will write to .bin file[s], but you can read it back
    # or change the function internals by copying them from source
    write_weights(groups, directory, write_manifest=False)
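If installing TensorFlow just for these two helpers is too heavy, the same round trip can be sketched with NumPy alone. This is not the tensorflowjs implementation: decode_bin and encode_bin are hypothetical helpers, and they assume a single .bin file holding unquantized little-endian float32/int32 weights in manifest order.

```python
from math import prod

import numpy as np

# Assumption: only these unquantized dtypes appear in the manifest.
DTYPES = {"float32": "<f4", "int32": "<i4"}


def decode_bin(buffer, weight_specs):
    """Split one .bin buffer into {name: ndarray} using manifest entries."""
    arrays, offset = {}, 0
    for spec in weight_specs:
        dtype = np.dtype(DTYPES[spec["dtype"]])
        count = prod(spec["shape"])
        flat = np.frombuffer(buffer, dtype=dtype, count=count, offset=offset)
        # copy() detaches the array from the read-only bytes buffer
        arrays[spec["name"]] = flat.reshape(spec["shape"]).copy()
        offset += count * dtype.itemsize
    return arrays


def encode_bin(arrays, weight_specs):
    """Concatenate arrays back into bytes, in manifest order."""
    chunks = []
    for spec in weight_specs:
        arr = np.asarray(arrays[spec["name"]], dtype=DTYPES[spec["dtype"]])
        if list(arr.shape) != list(spec["shape"]):
            raise ValueError(f"shape mismatch for {spec['name']}")
        chunks.append(arr.tobytes())  # row-major (C order) bytes
    return b"".join(chunks)


# Made-up manifest entries and weights standing in for model.json / the .bin.
specs = [
    {"name": "dense/kernel", "shape": [2, 3], "dtype": "float32"},
    {"name": "dense/bias", "shape": [3], "dtype": "float32"},
]
original = np.arange(9, dtype="<f4").tobytes()

arrays = decode_bin(original, specs)
arrays["dense/bias"] *= 2  # stand-in for the real processing step
new_bytes = encode_bin(arrays, specs)
```

Here weight_specs would be model['weightsManifest'][0]['weights'] from the real model.json. Tensors coming back from PyTorch need tensor.detach().cpu().numpy() before being passed to encode_bin.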