简体   繁体   中英

Python: stream a tarfile to S3 using multipart upload

I would like to create a .tar file in an S3 bucket from Python code running in an AWS Lambda function. Lambda functions are very memory- and disk-constrained. I want to create a .tar file that contains multiple files that are too large to fit in the Lambda function's memory or disk space.

Using S3 multipart upload, it is possible to upload a large file in chunks ("parts") of 5MB or more; only the final part may be smaller. I have this figured out and working. What I need to figure out is how to manage an in-memory byte buffer that won't grow past the limits of the Lambda function's runtime environment.

I think the solution is to create an io.BytesIO() object and manage both a read pointer and a write pointer. I can then write into the buffer (from files that I want to add to the .tar file) and every time the buffer exceeds some limit (like 5MB) I can read off a chunk of data and send another file part to S3.

What I haven't quite wrapped my head around is how to discard the part of the buffer that has already been read and is no longer needed in memory. I need to trim the head of the buffer, not the tail, so the truncate() method of BytesIO (which only removes data after a given position) won't work for me.

Is the 'correct' solution to create a new BytesIO buffer, populating it with the contents of the existing buffer from the read pointer to the end of the buffer, when I truncate? Is there a better way to truncate the head of the BytesIO buffer? Is there a better solution than using BytesIO?

For the random Googler who stumbles onto this question six years in the future and thinks, "man, that describes my problem exactly!", here's what I came up with:

import io
import struct
from tarfile import BLOCKSIZE

#This class was designed to write a .tar file to S3 using multipart upload
#in a memory- and disk-constrained environment, such as AWS Lambda functions.
#
#Much of this code is copied or adapted from the Python source code tarfile.py
#file at https://github.com/python/cpython/blob/3.10/Lib/tarfile.py
#
#No warranties expressed or implied. Your mileage may vary. Lather, rinse, repeat

class StreamingTarFileWriter:
    """Build a GNU-format .tar archive incrementally in a bounded memory buffer.

    All output is written into ``self.buf`` (a ``MemoryByteStream``). Each time
    a write leaves more than ``bufferFullByteCount`` unread bytes in it,
    ``bufferFullCallback(stream, bytesAvailable)`` is called so the caller can
    read chunks off the head (e.g. to upload them as S3 multipart parts).

    Call sequence per member: ``addFileRecord()``, then ``addFileData()`` any
    number of times, then ``completeFileRecord()``. Finish the archive with
    ``completeTarFile()`` and afterwards read the leftover bytes with
    ``self.buf.read()``. The callback only fires while the buffer is above
    the threshold, so the tail of the archive is never handed over
    automatically.

    Much of this is copied or adapted from CPython's Lib/tarfile.py.
    """

    # Various constants from tarfile.py that we need
    GNU_FORMAT = 1
    NUL = b"\0"
    BLOCKSIZE = 512              # size of one tar header/data block
    RECORDSIZE = BLOCKSIZE * 20  # the finished archive is padded to a multiple of this

    class MemoryByteStream:
        """In-memory FIFO byte queue: append at the tail, consume from the head.

        BytesIO cannot discard bytes from its front. After every read the
        unread remainder is therefore copied into a fresh BytesIO, so memory
        use tracks the amount of unread data only.
        """

        def __init__(self, bufferFullCallback = None, bufferFullByteCount = 0):
            self.buf = io.BytesIO()
            self.readPointer = 0   # offset of the first unread byte in self.buf
            self.writePointer = 0  # offset just past the last written byte
            # Invoked as bufferFullCallback(self, bytesAvailable) after a write
            # leaves more than bufferFullByteCount unread bytes. A count of 0
            # (or no callback) disables it.
            self.bufferFullCallback = bufferFullCallback
            self.bufferFullByteCount = bufferFullByteCount

        def write(self, buf: bytes):
            """Append ``buf`` and fire the buffer-full callback if over the threshold."""
            self.buf.seek(self.writePointer)
            self.writePointer += self.buf.write(buf)
            bytesAvailableToRead = self.writePointer - self.readPointer
            if self.bufferFullByteCount > 0 and bytesAvailableToRead > self.bufferFullByteCount:
                if self.bufferFullCallback:
                    self.bufferFullCallback(self, bytesAvailableToRead)

        def read(self, byteCount = None):
            """Consume and return up to ``byteCount`` bytes from the head.

            ``None`` (or a negative count) returns everything unread;
            ``0`` returns ``b""``.
            """
            self.buf.seek(self.readPointer)
            # Compare against None explicitly: read(0) must yield b"", not
            # the entire buffer.
            if byteCount is None:
                chunk = self.buf.read()
            else:
                chunk = self.buf.read(byteCount)
            self.readPointer += len(chunk)
            self._truncate()
            return chunk

        def size(self):
            """Return the number of buffered bytes not yet read."""
            return self.writePointer - self.readPointer

        def _truncate(self):
            """Drop already-read bytes by copying the unread tail into a new BytesIO."""
            if self.readPointer == 0:
                return  # nothing consumed yet, so nothing to discard
            self.buf.seek(self.readPointer)
            self.buf = io.BytesIO(self.buf.read())
            self.readPointer = 0
            self.writePointer = self.buf.seek(0, io.SEEK_END)

    def stn(self, s, length, encoding, errors):
        """Encode ``s`` into a NUL-padded field of exactly ``length`` bytes.

        Longer encodings are silently truncated to ``length`` bytes.
        """
        s = s.encode(encoding, errors)
        return s[:length] + (length - len(s)) * self.NUL

    def itn(self, n, digits=8, format=GNU_FORMAT):
        """Convert a python number to a ``digits``-byte numeric header field.

        POSIX 1003.1-1988 requires numbers to be encoded as a string of octal
        digits followed by a NUL byte, allowing values up to 8**(digits-1)-1.
        GNU tar can store larger values: a leading 0o200 (positive) or 0o377
        (negative) byte is followed by digits-1 bytes of big-endian base-256,
        allowing values up to 256**(digits-1)-1.

        :raises ValueError: if ``n`` does not fit in the field.
        """
        n = int(n)
        if 0 <= n < 8 ** (digits - 1):
            s = bytes("%0*o" % (digits - 1, n), "ascii") + self.NUL
        elif format == self.GNU_FORMAT and -256 ** (digits - 1) <= n < 256 ** (digits - 1):
            if n >= 0:
                s = bytearray([0o200])
            else:
                s = bytearray([0o377])
                n = 256 ** digits + n

            # Insert each byte right after the marker, so the most
            # significant byte ends up first (big-endian).
            for _ in range(digits - 1):
                s.insert(1, n & 0o377)
                n >>= 8
        else:
            raise ValueError("overflow in number field")

        return s

    def calc_chksums(self, buf):
        """Calculate the checksum for a member's header by summing up all
        characters except for the chksum field which is treated as if
        it was filled with spaces. According to the GNU tar sources,
        some tars (Sun and NeXT) calculate chksum with signed char,
        which will be different if there are chars in the buffer with
        the high bit set. So we calculate two checksums, unsigned and
        signed.
        """
        # "8x" skips the 8-byte chksum field at offset 148; 256 == 8 * ord(" ")
        # accounts for it as if it were all spaces.
        unsigned_chksum = 256 + sum(struct.unpack_from("148B8x356B", buf))
        signed_chksum = 256 + sum(struct.unpack_from("148b8x356b", buf))
        return unsigned_chksum, signed_chksum

    def __init__(self, bufferFullCallback = None, bufferFullByteCount = 0):
        """Create a writer whose buffer calls ``bufferFullCallback`` above ``bufferFullByteCount`` bytes."""
        self.buf = self.MemoryByteStream(bufferFullCallback, bufferFullByteCount)
        self.expectedFileSize = 0   # size declared by the current addFileRecord()
        self.fileBytesWritten = 0   # data bytes written so far for the current member
        self.offset = 0             # total archive bytes written so far

    def addFileRecord(self, filename, filesize):
        """Write the 512-byte header block that starts a new regular-file member.

        The header uses GNU tar magic, mode 0o644, uid/gid 0, mtime 0 and
        empty owner names.

        :param filename: member name; must encode (UTF-8, surrogateescape) to
            at most 100 bytes, since the ustar prefix field is left empty.
        :param filesize: exact number of data bytes that will be passed to
            addFileData() before completeFileRecord().
        :raises ValueError: if the name is too long or ``filesize`` is negative.
        """
        REGTYPE = b"0"                  # regular file
        encoding = "utf-8"
        errors = "surrogateescape"
        LENGTH_NAME = 100
        GNU_MAGIC = b"ustar  \0"        # magic gnu tar string

        if len(filename.encode(encoding, errors)) > LENGTH_NAME:
            raise ValueError("Filename is too long for tar file header.")
        if filesize < 0:
            raise ValueError(f"File size must be non-negative, got {filesize}.")

        # Field layout adapted from tarfile.TarInfo.tobuf() / _create_header().
        parts = [
            self.stn(filename, 100, encoding, errors),       # name
            self.itn(0o644 & 0o7777, 8, self.GNU_FORMAT),    # mode
            self.itn(0, 8, self.GNU_FORMAT),                 # uid
            self.itn(0, 8, self.GNU_FORMAT),                 # gid
            self.itn(filesize, 12, self.GNU_FORMAT),         # size (base-256 above 8 GiB)
            self.itn(0, 12, self.GNU_FORMAT),                # mtime
            b"        ",                                     # chksum placeholder (8 spaces)
            REGTYPE,                                         # typeflag
            self.stn("", 100, encoding, errors),             # linkname
            GNU_MAGIC,                                       # magic + version
            self.stn("", 32, encoding, errors),              # uname
            self.stn("", 32, encoding, errors),              # gname
            self.stn("", 8, encoding, errors),               # devmajor
            self.stn("", 8, encoding, errors),               # devminor
            self.stn("", 155, encoding, errors),             # prefix
        ]
        header = struct.pack("%ds" % self.BLOCKSIZE, b"".join(parts))
        chksum = self.calc_chksums(header[-self.BLOCKSIZE:])[0]
        # Overwrite bytes 148..154 with six octal digits + NUL; byte 155 keeps
        # the placeholder's trailing space (same layout tarfile writes).
        header = header[:-364] + bytes("%06o\0" % chksum, "ascii") + header[-357:]
        self.buf.write(header)
        self.expectedFileSize = filesize
        self.fileBytesWritten = 0
        self.offset += len(header)

    def addFileData(self, buf):
        """Append file content for the member opened by addFileRecord().

        :raises ValueError: if this would exceed the declared file size. The
            check happens before writing, so no stray bytes reach the stream.
        """
        newTotal = self.fileBytesWritten + len(buf)
        if newTotal > self.expectedFileSize:
            raise ValueError(
                f"Data exceeds declared file size: {self.expectedFileSize:,} bytes "
                f"expected but {newTotal:,} would be written."
            )
        self.buf.write(buf)
        self.fileBytesWritten = newTotal
        self.offset += len(buf)

    def completeFileRecord(self):
        """Finish the current member by NUL-padding its data to a whole block.

        :raises ValueError: if fewer bytes were written than the header declared.
        """
        if self.fileBytesWritten != self.expectedFileSize:
            raise ValueError(f"Expected {self.expectedFileSize:,} bytes but {self.fileBytesWritten:,} were written.")

        remainder = self.fileBytesWritten % self.BLOCKSIZE
        if remainder:
            padding = self.BLOCKSIZE - remainder
            self.buf.write(self.NUL * padding)
            self.offset += padding

    def completeTarFile(self):
        """Write the end-of-archive marker and pad the archive to a whole record.

        Afterwards ``self.offset`` equals the total archive size. Any bytes
        still in ``self.buf`` must be read by the caller.
        """
        endMarker = self.NUL * (self.BLOCKSIZE * 2)  # two zero blocks end a tar
        self.buf.write(endMarker)
        self.offset += len(endMarker)

        remainder = self.offset % self.RECORDSIZE
        if remainder:
            padding = self.RECORDSIZE - remainder
            self.buf.write(self.NUL * padding)
            self.offset += padding  # previously omitted, leaving offset short

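An example use of the class is shown below. Note that the buffer-full callback only fires while more than the threshold number of bytes are buffered. After calling completeTarFile(), you must therefore read the remaining bytes from the writer's buf (e.g. buf.read()) and upload them as the final part.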

OUTPUT_CHUNK_SIZE = 5 * 1024 * 1024  # bytes per uploaded chunk; S3 multipart parts (except the last) must be >= 5MB
f_out = open("test.tar", mode="wb")  # local stand-in for the S3 multipart upload target

def get_file_block(blockNum, blockSize=512):
    """Return exactly ``blockSize`` bytes of deterministic ASCII filler for block ``blockNum``.

    The data starts with a label such as ``block_00,000,007`` followed by a
    repeating hex pattern. The result is cut to ``blockSize`` bytes so callers
    that count ``divmod(numBytes, blockSize)`` blocks stay in sync even once
    the label grows past 10 characters (blockNum >= 100,000,000).
    """
    label = f"block_{blockNum:010,}"
    filler = "0123456789abcdef" * (blockSize // 16 + 1)
    return bytes((label + filler)[:blockSize], 'ascii')

def buffer_full_callback(x: StreamingTarFileWriter.MemoryByteStream, bytesAvailable: int):
    """Move whole OUTPUT_CHUNK_SIZE chunks from the stream ``x`` to the output.

    Invoked by MemoryByteStream.write() once more than OUTPUT_CHUNK_SIZE bytes
    are buffered. Chunks are read while the buffer holds strictly more than
    OUTPUT_CHUNK_SIZE bytes; anything at or below that stays buffered.
    ``bytesAvailable`` is part of the callback signature but is not used.
    """
    while True:
        if x.size() <= OUTPUT_CHUNK_SIZE:
            break
        chunk = x.read(OUTPUT_CHUNK_SIZE)
        # In a real deployment, upload `chunk` as the next S3 multipart part here.
        f_out.write(chunk)

x = StreamingTarFileWriter(buffer_full_callback, OUTPUT_CHUNK_SIZE)

import random
numFiles = random.randint(3,8)
print(f"Creating {numFiles:,} files.")

minSize = 1025 #1kB plus 1 byte
maxSize = 10 * 1024 * 1024 * 1024 + 5 #10GB plus 5 bytes

try:
    for fileIdx in range(numFiles):
        numBytes = random.randint(minSize, maxSize)
        print(f"Creating file {str(fileIdx)} with {numBytes:,} bytes.")
        blocks, remainder = divmod(numBytes, 512)

        x.addFileRecord(f"File{str(fileIdx)}", numBytes)
        for block in range(blocks):
            x.addFileData(get_file_block(block))
        x.addFileData(bytes(("X" * remainder), 'ascii'))
        x.completeFileRecord()

    # Without the end-of-archive blocks and record padding the .tar is incomplete.
    x.completeTarFile()

    # The callback only drains while more than OUTPUT_CHUNK_SIZE bytes are
    # buffered, so the archive's tail is still sitting in the buffer. With S3
    # this becomes the final part, which is allowed to be smaller than 5MB.
    tail = x.buf.read()
    if tail:
        f_out.write(tail)
finally:
    # Close even on failure so buffered file data is flushed to disk.
    f_out.close()

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM