Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions src/c2pa/c2pa.py
Comment thread
tmathern marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sys
import os
import warnings
import weakref
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union, Callable, Any, overload
Expand Down Expand Up @@ -1614,6 +1615,10 @@ def __init__(self, file_like_stream):

self._file_like_stream = file_like_stream

# Weakref breaks avoids references cycles:
# Stream -> ctypes_cb -> closure -> Stream.
_weak_self = weakref.ref(self)

def read_callback(ctx, data, length):
"""Callback function for reading data from the Python stream.

Expand All @@ -1632,13 +1637,16 @@ def read_callback(ctx, data, length):
Returns:
Number of bytes read, or -1 on error
"""
if not self._initialized or self._closed:
s = _weak_self()
if s is None:
return -1
if not s._initialized or s._closed:
return -1
try:
if not data or length <= 0:
return -1

stream = self._file_like_stream
stream = s._file_like_stream
readinto = getattr(stream, "readinto", None)
if readinto is not None:
# Most streams have readinto
Expand Down Expand Up @@ -1689,8 +1697,11 @@ def seek_callback(ctx, offset, whence):
Returns:
New position in the stream, or -1 on error
"""
file_stream = self._file_like_stream
if not self._initialized or self._closed:
s = _weak_self()
if s is None:
return -1
file_stream = s._file_like_stream
if not s._initialized or s._closed:
return -1
try:
# Fall back to tell() only for stream objects that do not
Expand Down Expand Up @@ -1718,13 +1729,16 @@ def write_callback(ctx, data, length):
Returns:
Number of bytes written, or -1 on error
"""
if not self._initialized or self._closed:
s = _weak_self()
if s is None:
return -1
if not s._initialized or s._closed:
return -1
try:
if not data or length <= 0:
return -1

self._file_like_stream.write(ctypes.string_at(data, length))
s._file_like_stream.write(ctypes.string_at(data, length))
return length
except Exception:
return -1
Expand All @@ -1743,10 +1757,13 @@ def flush_callback(ctx):
Returns:
0 on success, -1 on error
"""
if not self._initialized or self._closed:
s = _weak_self()
if s is None:
return -1
if not s._initialized or s._closed:
return -1
try:
self._file_like_stream.flush()
s._file_like_stream.flush()
return 0
except Exception:
return -1
Expand Down
63 changes: 63 additions & 0 deletions tests/test_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6255,5 +6255,68 @@ def test_sign_callback_signer_in_ctx(self):
context.close()


class TestStreamReferences(unittest.TestCase):

def test_stream_collected_after_del(self):
"""Stream must be collected by reference counting."""
import gc
import weakref
gc.collect()
garbage_before = len(gc.garbage)

buf = io.BytesIO(b"hello world")
s = Stream(buf)
ref = weakref.ref(s)
del s

# Trigger gc, we want to verify it's collected
# collected now (1 gc call) means no gc cycle breaker needed
# aka ref cycle did not happen.
gc.collect()

self.assertIsNone(ref(), "Stream not collected")
self.assertEqual(len(gc.garbage), garbage_before,
"Stream added objects to gc.garbage")

def test_stream_not_added_to_gc_garbage_list(self):
"""Creating and dropping many Streams must not grow gc.garbage."""
import gc
gc.collect()
gc.garbage.clear()

for _ in range(20):
s = Stream(io.BytesIO(b"data"))
del s

gc.collect()
self.assertEqual(len(gc.garbage), 0,
f"gc.garbage unexpectedly non-empty: {gc.garbage}")

def test_callbacks_return_minus_one_after_stream_collected(self):
"""Callbacks must return -1 gracefully when the Stream has been GC'd."""
import gc
import weakref

buf = io.BytesIO(b"test")
s = Stream(buf)

read_cb = s._read_cb
seek_cb = s._seek_cb
write_cb = s._write_cb
flush_cb = s._flush_cb
ref = weakref.ref(s)
del s
gc.collect()

# Stream should be gone.
self.assertIsNone(ref(), "Stream not collected before callback test")

# All callbacks must return -1 without crashing.
self.assertEqual(read_cb(None, None, 0), -1)
self.assertEqual(seek_cb(None, 0, 0), -1)
self.assertEqual(write_cb(None, None, 0), -1)
self.assertEqual(flush_cb(None), -1)


if __name__ == '__main__':
unittest.main(warnings='ignore')
Loading