
How to store my own class object into hdf5?

I created a class to hold the experimental results of my research (I'm an EE PhD student), for example:

class Trial:
    def __init__(self, subID, triID):
        self.filePath = ''     # file path of the folder
        self.subID = subID     # int
        self.triID = triID     # int
        self.data_A = -1       # numpy array
        self.data_B = -1       # numpy array
        # ... more bool, int and numpy array attributes

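It mixes many bools, ints and numpy arrays. You get the idea. I read that data loads faster if it is stored in HDF5 format. Can I do that with my data, which is a Python list of my Trial objects?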

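Note that there is a similar question on Stack Overflow, but it has only one answer, and that answer doesn't answer the question. Instead, it breaks the OP's custom class down into basic data types and stores them in separate datasets. I'm not against doing that, but I'd like to know whether it is the only way, because it goes against the object-oriented philosophy.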
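The decomposition described above can still keep one object per unit if each Trial gets its own HDF5 group. A minimal sketch, assuming h5py is installed and that a Trial holds only scalars (bool/int) and numpy arrays; the `Trial` fields, the helper names `save_trials`/`load_trials`, and the `trial_00000` group naming are hypothetical:

```python
import numpy as np
import h5py


class Trial:
    """Hypothetical stand-in for the question's Trial class (fields assumed)."""
    def __init__(self, subID=-1, triID=-1, valid=False, data_A=None, data_B=None):
        self.subID = subID
        self.triID = triID
        self.valid = valid
        self.data_A = np.zeros(0) if data_A is None else np.asarray(data_A)
        self.data_B = np.zeros(0) if data_B is None else np.asarray(data_B)


def save_trials(trials, filename):
    """Write each Trial into its own group: numpy arrays become datasets,
    scalars become attributes of that group."""
    with h5py.File(filename, 'w') as f:
        for k, trial in enumerate(trials):
            # zero-padded names keep the original order when sorted
            grp = f.create_group('trial_{:05d}'.format(k))
            for name, value in vars(trial).items():
                if isinstance(value, np.ndarray):
                    grp.create_dataset(name, data=value)
                else:
                    grp.attrs[name] = value


def load_trials(filename):
    """Rebuild Trial objects from the groups written by save_trials."""
    trials = []
    with h5py.File(filename, 'r') as f:
        for key in sorted(f.keys()):
            grp = f[key]
            trial = Trial.__new__(Trial)  # skip __init__, attributes are set below
            for name, value in grp.attrs.items():
                setattr(trial, name, value)
            for name, dset in grp.items():
                setattr(trial, name, dset[()])  # read the whole dataset into memory
            trials.append(trial)
    return trials
```

Values come back as numpy types (for example `np.int64`, `np.bool_`), not plain Python ones.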

Here's a small class I use to save data like this. You can use it by doing something like:

dc = DataContainer()
dc.trials = <your list of trial objects here>
dc.save('mydata.pkl')

Then, to load it:

dc = DataContainer.load('mydata.pkl')

Here is the DataContainer file:

import gzip
import pickle  # Python 3 (the original Python 2 code used "import cPickle as pickle")

# Simple container with load and save methods.  Declare the container
# then add data to it.  Save will save any data added to the container.
# The class automatically gzips the file if it ends in .gz
#
# Notes on size and speed (using UbuntuDialog data)
#       pkl     pkl.gz
# Save  11.4s   83.7s
# Load   4.8s   45.0s
# Size  596M    205M
#
class DataContainer(object):
    @staticmethod
    def isGZIP(filename):
        if filename.split('.')[-1] == 'gz':
            return True
        return False

    # Using HIGHEST_PROTOCOL is almost 2X faster and creates a file that
    # is ~10% smaller.  Load times go down by a factor of about 3X.
    def save(self, filename='DataContainer.pkl'):
        opener = gzip.open if self.isGZIP(filename) else open
        # "with" closes the file even if pickling raises
        with opener(filename, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Note that loading to a string with pickle.loads is about 10% faster
    # but probably consumes a lot more memory, so we'll skip that for now.
    @classmethod
    def load(cls, filename='DataContainer.pkl'):
        opener = gzip.open if cls.isGZIP(filename) else open
        with opener(filename, 'rb') as f:
            return pickle.load(f)

Depending on your use case, you can use this as a base class as described at the top, or simply copy the pickle.dump line into your code.
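For the "copy the pickle.dump line" route, a minimal sketch that pickles the list of trial objects directly, without a container class. It assumes Python 3 and picklable objects; `save_list` and `load_list` are hypothetical helper names, and the `.gz` switch mirrors `DataContainer.isGZIP`:

```python
import gzip
import pickle


def save_list(objs, filename):
    """Hypothetical helper: pickle a list of objects, gzipped if name ends in .gz."""
    opener = gzip.open if filename.endswith('.gz') else open
    with opener(filename, 'wb') as f:
        pickle.dump(objs, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_list(filename):
    """Hypothetical helper: load a list written by save_list."""
    opener = gzip.open if filename.endswith('.gz') else open
    with opener(filename, 'rb') as f:
        return pickle.load(f)
```

pickle records classes by module and name, so the `Trial` class has to be importable under the same name when the file is loaded. Only load pickle files you trust.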

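If you really have a lot of data and don't use all of it every time you run your test program, there are other options, such as a database. But assuming you need most of the data on every run, the above is about the best simple option.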
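As an illustration of the database option, a minimal sketch using the standard-library `sqlite3` module: each trial is stored as a pickled BLOB keyed by `(subID, triID)`, so a run can load one subject instead of everything. The table layout and the helpers `save_trials_db`/`load_subject` are hypothetical, and `subID`/`triID` are assumed to be plain Python ints (sqlite3 does not accept numpy integers by default):

```python
import pickle
import sqlite3


def save_trials_db(trials, filename):
    """Hypothetical helper: one row per trial, object pickled into a BLOB."""
    con = sqlite3.connect(filename)
    with con:  # commits the transaction on success
        con.execute('CREATE TABLE IF NOT EXISTS trials '
                    '(subID INTEGER, triID INTEGER, obj BLOB, '
                    'PRIMARY KEY (subID, triID))')
        con.executemany(
            'INSERT OR REPLACE INTO trials VALUES (?, ?, ?)',
            [(t.subID, t.triID, pickle.dumps(t, protocol=pickle.HIGHEST_PROTOCOL))
             for t in trials])
    con.close()


def load_subject(filename, subID):
    """Hypothetical helper: load only the trials of one subject."""
    con = sqlite3.connect(filename)
    rows = con.execute('SELECT obj FROM trials WHERE subID = ? ORDER BY triID',
                       (subID,)).fetchall()
    con.close()
    return [pickle.loads(blob) for (blob,) in rows]
```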

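I haven't tested the following solution for speed or storage efficiency. HDF5 does support "compound data types", which can be used with numpy "structured arrays" that support mixed variable types like the ones in your class objects.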

"""
Created on Tue Dec 10 21:26:54 2019

@author: Christopher J. Burke
Give a worked example of saving a list of class objects with mixed
storage types to an HDF5 file and reading the file back into a list of class
objects.  The solution is inspired by this bug report
https://github.com/h5py/h5py/issues/735
and the numpy and hdf5 documentation
"""

import numpy as np
import h5py

class test_object:
    """ Define a storage class that keeps info that we want to record
      for every object
    """
    # explicitly state the name, datatype and shape for every
    #  class variable
    #  The names MUST exactly match the class variable names in the __init__
    store_names = ['a', 'b', 'c', 'd', 'e']
    store_types = ['i8', 'i4', 'f8', 'S80', 'f8']
    store_shapes = [None, None, None, None, (4,)]
    # Make the tuples that will define the numpy structured array
    # https://docs.scipy.org/doc/numpy/user/basics.rec.html
    sz = len(store_names)
    store_def_tuples = []
    for i in range(sz):
        if store_shapes[i] is not None:
            store_def_tuples.append((store_names[i], store_types[i], store_shapes[i]))
        else:
            store_def_tuples.append((store_names[i], store_types[i]))
    # Actually define the numpy structured/compound data type
    store_struct_numpy_dtype = np.dtype(store_def_tuples)

    def __init__(self):
        self.a = 0
        self.b = 0
        self.c = 0.0
        self.d = '0'
        self.e = [0.0, 0.0, 0.0, 0.0]

    def store_objlist_as_hd5f(self, objlist, fileName):
        """Function to save the class structure into hdf5
        objlist -  is a list of the test_objects
        fileName - is the h5 filename for output
        """
        # First create a numpy structured array with one record per object
        np_dset = np.zeros(len(objlist), dtype=self.store_struct_numpy_dtype)
        # Copy each class variable into the matching field of the structured array
        for i, curobj in enumerate(objlist):
            for name in self.store_names:
                np_dset[name][i] = getattr(curobj, name)
        # Data set is fully loaded and ready to write out
        with h5py.File(fileName, 'w') as fp:
            fp.create_dataset('dset', data=np_dset)

    def fill_objlist_from_hd5f(self, fileName):
        """ Function to read in the hdf5 file created by store_objlist_as_hd5f
          and store the contents into a list of test_objects
          fileName - is the h5 filename for input
         """
        # Read the whole compound dataset into memory; "with" closes the file
        with h5py.File(fileName, 'r') as fp:
            np_dset = fp['dset'][()]
        # Start with empty list
        all_objs = []
        # iterate through the numpy structured array and save to objects
        for i in range(len(np_dset)):
            tmp = test_object()
            for name in self.store_names:
                setattr(tmp, name, np_dset[name][i])
            # Append object to list
            all_objs.append(tmp)
        return all_objs

if __name__ == '__main__':

    all_objs = []
    for i in range(3):
        # instantiate a test_object
        tmp = test_object()
        # Put some dummy data into the object
        tmp.a = int(i)
        tmp.b = int(i)
        tmp.c = float(i)
        tmp.d = '{0} {0} {0} {0}'.format(i)
        tmp.e = np.full([4], i, dtype=np.float64)  # np.float was removed in NumPy 1.24
        all_objs.append(tmp)

    # Write out hdf5 file
    tmp.store_objlist_as_hd5f(all_objs, 'test_write.h5')

    # Read in hdf5 file
    all_objs = tmp.fill_objlist_from_hd5f('test_write.h5')

    # verify the output is as expected
    for i, curobj in enumerate(all_objs):
        print('Object {0:d}'.format(i))
        print('{0:d} {1:d} {2:f}'.format(curobj.a, curobj.b, curobj.c))
        print('{0} {1}'.format(curobj.d.decode('ASCII'), curobj.e))
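Because all objects end up in one compound dataset, a run that needs only part of the data can index the dataset by field name or by a slice of records instead of rebuilding every object. A minimal sketch, separate from the example above, with hypothetical helper names `write_demo`, `read_field` and `read_rows`; the speed benefit is not measured here:

```python
import numpy as np
import h5py

# Same kind of compound dtype as test_object builds (fields 'a' ... 'e')
dt = np.dtype([('a', 'i8'), ('b', 'i4'), ('c', 'f8'), ('d', 'S80'), ('e', 'f8', (4,))])


def write_demo(filename, n=3):
    """Hypothetical helper: write n dummy records to dataset 'dset'."""
    rec = np.zeros(n, dtype=dt)
    rec['a'] = np.arange(n)
    rec['c'] = np.arange(n, dtype=np.float64)
    rec['e'] = np.arange(n, dtype=np.float64)[:, None]  # broadcast to shape (n, 4)
    with h5py.File(filename, 'w') as fp:
        fp.create_dataset('dset', data=rec)


def read_field(filename, field):
    """Return one field for every record as a plain numpy array."""
    with h5py.File(filename, 'r') as fp:
        return fp['dset'][field]


def read_rows(filename, start, stop):
    """Return records start..stop-1 with all fields."""
    with h5py.File(filename, 'r') as fp:
        return fp['dset'][start:stop]
```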


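Disclaimer: The technical posts on this site follow the CC BY-SA 4.0 license. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.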

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM