Skip to content

Commit

Permalink
Make BinPickler a pickler
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Dec 19, 2023
1 parent 338ec34 commit 5554b66
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions binpickle/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _align_pos(pos: int, size: int = mmap.PAGESIZE) -> int:
return pos


class BinPickler:
class BinPickler(pickle.Pickler):
"""
Save an object into a binary pickle file. This is like :class:`pickle.Pickler`,
except it works on file paths instead of byte streams.
Expand All @@ -38,6 +38,9 @@ class BinPickler:
with BinPickler('file.bpk') as bpk:
bpk.dump(obj)
Only one object can be dumped to a `BinPickler`. Other methods are exposed
for manually constructing BinPickle files but their use is highly discouraged.
Args:
filename(str or pathlib.Path):
The path to the file to write.
Expand All @@ -58,6 +61,7 @@ class BinPickler:
to vary from buffer to buffer.
"""

_pickle_stream: io.BytesIO
filename: str | PathLike[str]
align: bool
codecs: list[ResolvedCodec]
Expand All @@ -71,11 +75,17 @@ def __init__(
align: bool = False,
codecs: Optional[list[CodecArg]] = None,
):
self._pickle_stream = io.BytesIO()
self.filename = filename
self.align = align
self._file = open(filename, "wb")
self.entries = []

# set up the binpickler
super().__init__(
self._pickle_stream, pickle.HIGHEST_PROTOCOL, buffer_callback=self._write_buffer
)

if codecs is None:
self.codecs = []
else:
Expand All @@ -96,12 +106,8 @@ def compressed(cls, filename: str | PathLike[str], codec: CodecArg = "gzip"):

def dump(self, obj: object) -> None:
"Dump an object to the file. Can only be called once."
bio = io.BytesIO()
pk = pickle.Pickler(
bio, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=self._write_buffer
)
pk.dump(obj)
buf = bio.getbuffer()
super().dump(obj)
buf = self._pickle_stream.getbuffer()

tot_enc = sum(e.enc_length for e in self.entries)
tot_dec = sum(e.dec_length for e in self.entries)
Expand Down

0 comments on commit 5554b66

Please sign in to comment.