How to convert tensorflow.js model weights to pytorch tensors, and back?

I am using ml5.js, a wrapper around TensorFlow.js. I want to train a neural network in the browser, download the weights, process them as tensors in PyTorch, and load them back into the browser's TensorFlow.js model. How do I convert between these formats (tfjs <-> pytorch)?

The browser model has a save() function which generates three files: a metadata file specific to ml5.js (json), a topology file describing the model architecture (json), and a binary weights file (bin).

// Browser
model.save()
// HTTP/Download
model_meta.json   (needed by ml5.js)
model.json        (needed by tfjs)
model.weights.bin (needed by tfjs)
# python backend
import json

with open('model.weights.bin', 'rb') as weights_file:
    with open('model.json', 'rb') as model_file:
        weights = weights_file.read()
        model = json.loads(model_file.read())
        ####
        pytorch_tensor = convert2tensor(weights, model)  # what goes in this function?
        ####
        # Do some processing in pytorch

        ####
        new_weights_bin = convert2bin(pytorch_tensor, model) # and in this?
        ####

Here is sample JavaScript code to generate and load the 3 files in the browser. To load, select all 3 files at once in the dialog box. If they are correct, a popup will show a sample prediction.

I was able to find a way to convert from the tfjs model.weights.bin to numpy ndarrays. From there it is trivial to convert the numpy arrays to a pytorch state_dict, which is a dictionary mapping tensor names to tensors.

Theory

First, the tfjs representation of the model should be understood. model.json describes the model. In Python, it can be read as a dictionary. It has the following keys:

  1. The model architecture is described as another json/dictionary under the key modelTopology.

  2. It also has a JSON list under the key weightsManifest, with one entry per group of weights. Each entry describes the type/shape/location of each weight wrapped up in the corresponding model.weights.bin file. As an aside, the weights manifest allows for multiple .bin files to store weights.
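To make the layout above concrete, here is a minimal sketch that reads a model.json and lists what the weights manifest says about each weight. The JSON string is a hypothetical, heavily trimmed example (the weight names, shapes and topology are made up for illustration), and summarize_manifest is a helper name introduced here, not part of any library.

```python
import json
from math import prod

# Hypothetical model.json content, trimmed to the two keys discussed above.
model_json = """
{
  "modelTopology": {"class_name": "Sequential"},
  "weightsManifest": [
    {
      "paths": ["./model.weights.bin"],
      "weights": [
        {"name": "dense/kernel", "shape": [2, 3], "dtype": "float32"},
        {"name": "dense/bias", "shape": [3], "dtype": "float32"}
      ]
    }
  ]
}
"""


def summarize_manifest(model: dict) -> list:
    """Return (name, shape, dtype, element_count) for every weight entry."""
    rows = []
    for group in model["weightsManifest"]:  # one entry per group of .bin paths
        for entry in group["weights"]:
            count = prod(entry["shape"])  # prod of an empty shape is 1 (scalar)
            rows.append((entry["name"], tuple(entry["shape"]), entry["dtype"], count))
    return rows


model = json.loads(model_json)
summary = summarize_manifest(model)
for row in summary:
    print(row)
```

The element counts matter later: the weights in the .bin file are laid out back to back in the order listed here.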

Tensorflow.js has a companion Python package, tensorflowjs, which comes with utility functions to read and write weights between the tf.js binary format and numpy arrays.

Each weight file is read as a "group". A group is a list of dictionaries with the keys name and data, which refer to the weight name and the numpy array containing the weights. There are optionally other keys too.

group = [{'name': weight_name, 'data': np.ndarray}, ...]   # 1 *.bin file
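To show what ends up in such a group, here is a rough, standard-library-only sketch of decoding one .bin buffer by hand. It is not the tensorflowjs implementation: it assumes every weight is unquantized float32, stored little-endian and concatenated in manifest order, and it returns flat Python lists instead of numpy arrays. The function name decode_float32_group and the sample manifest are hypothetical.

```python
import struct
from math import prod


def decode_float32_group(manifest_group: dict, buffer: bytes) -> list:
    """Decode one manifest group into [{'name', 'shape', 'data'}, ...].

    Assumption: unquantized float32 weights only; 'data' is a flat list.
    """
    group, offset = [], 0
    for entry in manifest_group["weights"]:
        if entry["dtype"] != "float32" or "quantization" in entry:
            raise ValueError(f"unsupported weight entry: {entry['name']}")
        count = prod(entry["shape"])
        # "<" = little-endian, "f" = 4-byte float, repeated `count` times
        values = struct.unpack_from(f"<{count}f", buffer, offset)
        offset += 4 * count
        group.append({"name": entry["name"], "shape": entry["shape"], "data": list(values)})
    return group


# Hypothetical manifest group and a matching 9-float buffer (values 0.0 .. 8.0).
manifest_group = {
    "paths": ["./model.weights.bin"],
    "weights": [
        {"name": "dense/kernel", "shape": [2, 3], "dtype": "float32"},
        {"name": "dense/bias", "shape": [3], "dtype": "float32"},
    ],
}
buffer = struct.pack("<9f", *[float(i) for i in range(9)])

group = decode_float32_group(manifest_group, buffer)
for weight in group:
    print(weight["name"], weight["shape"], weight["data"])
```

For real models the library functions are the safer route, since they also handle other dtypes and quantized weights.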

Application

Install tensorflowjs. Unfortunately, it will also install tensorflow.

pip install tensorflowjs

Use these functions. Note that I changed the signatures for convenience.

from typing import Dict
import numpy as np
import torch
from tensorflowjs.read_weights import decode_weights
from tensorflowjs.write_weights import write_weights

def convert2tensor(weights: bytes, model: Dict) -> Dict[str, torch.Tensor]:
    manifest = model['weightsManifest']
    # If flatten=False, returns a list of groups equal to the number of .bin files.
    # Use flatten=True to convert to a single group
    group = decode_weights(manifest, weights, flatten=True)
    # Convert dicts in tfjs group format into pytorch's state_dict format:
    # {name: str, data: ndarray} -> {name: tensor}
    state_dict = {d['name']: torch.from_numpy(d['data']) for d in group}
    return state_dict

def convert2bin(state_dict: Dict[str, np.ndarray], model: Dict, directory='./'):
    # values must be numpy arrays; convert torch tensors first,
    # e.g. with tensor.detach().cpu().numpy()
    # convert state_dict to groups (list of 1 group)
    groups = [[{'name': key, 'data': value} for key, value in state_dict.items()]]
    # this library function will write to .bin file[s], but you can read it back
    # or change the function internals by copying them from source
    write_weights(groups, directory, write_manifest=False)
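If you would rather produce the bytes for model.weights.bin yourself and keep the original model.json untouched, one possible alternative is to pack the values directly in manifest order. The sketch below uses only the standard library and rests on several assumptions: all weights are unquantized float32, the shapes are unchanged from the manifest, and the values are passed as flat lists of floats. encode_float32_group and the sample data are hypothetical names and values, not part of tensorflowjs.

```python
import struct


def encode_float32_group(manifest_group: dict, weights_by_name: dict) -> bytes:
    """Pack flat float lists into little-endian float32 bytes, in manifest order."""
    chunks = []
    for entry in manifest_group["weights"]:
        values = weights_by_name[entry["name"]]
        expected = 1
        for dim in entry["shape"]:
            expected *= dim
        if len(values) != expected:
            raise ValueError(
                f"{entry['name']}: expected {expected} values, got {len(values)}"
            )
        chunks.append(struct.pack(f"<{expected}f", *values))
    return b"".join(chunks)


# Hypothetical manifest group and processed weights (flattened, row-major).
manifest_group = {
    "paths": ["./model.weights.bin"],
    "weights": [
        {"name": "dense/kernel", "shape": [2, 3], "dtype": "float32"},
        {"name": "dense/bias", "shape": [3], "dtype": "float32"},
    ],
}
new_weights = {
    "dense/bias": [1.0, 2.0, 3.0],  # dict order does not matter
    "dense/kernel": [0.5] * 6,
}

packed = encode_float32_group(manifest_group, new_weights)
print(len(packed))  # 9 floats * 4 bytes each
```

With PyTorch tensors, the flat list for each weight could come from tensor.detach().cpu().flatten().tolist(). The resulting bytes would then be written to a file opened in 'wb' mode.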
