From 144146d5252248efa1db55fb5a56724d91f47b7c Mon Sep 17 00:00:00 2001
From: slozier <slozier@users.noreply.github.com>
Date: Fri, 31 Jan 2025 20:30:53 -0500
Subject: [PATCH] Add IBufferProtocol to mmap (#1866)

* Add IBufferProtocol to mmap

* Clean up interlocked or

* Add missing TryAddRef case
---
 src/core/IronPython.Modules/mmap.cs | 123 +++++++++++++++++++++++-----
 src/core/IronPython.Modules/re.cs   |   8 --
 2 files changed, 103 insertions(+), 28 deletions(-)

diff --git a/src/core/IronPython.Modules/mmap.cs b/src/core/IronPython.Modules/mmap.cs
index 8453bcb8e..56f99230b 100644
--- a/src/core/IronPython.Modules/mmap.cs
+++ b/src/core/IronPython.Modules/mmap.cs
@@ -7,6 +7,8 @@
 #if FEATURE_MMAP
 
 using System;
+using System.Buffers;
+using System.Collections.Generic;
 using System.Diagnostics;
 using System.Globalization;
 using System.IO;
@@ -278,7 +280,7 @@ private static MemoryMappedFileAccess ToMmapFileAccess(int access) {
         }
 
         [PythonHidden]
-        public class MmapDefault : IWeakReferenceable {
+        public class MmapDefault : IWeakReferenceable, IBufferProtocol {
             private MemoryMappedFile _file;
             private MemoryMappedViewAccessor _view;
             private long _position;
@@ -599,6 +601,11 @@ private bool TryAddRef(bool exclusive, out int reason) {
                         reason = StateBits.Exporting;
                         return false;
                     }
+                    if (exclusive && ((oldState & StateBits.RefCount) > StateBits.RefCountOne)) {
+                        // mmap in non-exclusive use, temporarily no exclusive use allowed
+                        reason = StateBits.Exclusive;
+                        return false;
+                    }
                     Debug.Assert((oldState & StateBits.RefCount) > 0, "resurrecting disposed mmap object (disposed without being closed)");
 
                     newState = oldState + StateBits.RefCountOne;
@@ -635,6 +642,9 @@ private void Release(bool exclusive) {
                     if (exclusive) {
                         newState &= ~StateBits.Exclusive;
                     }
+                    if ((newState & StateBits.RefCount) == StateBits.RefCountOne) {
+                        newState &= ~StateBits.Exporting;
+                    }
                 } while (Interlocked.CompareExchange(ref _state, newState, oldState) != oldState);
 
                 if (performDispose) {
@@ -648,25 +658,28 @@ private void Release(bool exclusive) {
                 }
             }
 
+            /// <summary>Atomically ORs <paramref name="value"/> into <c>_state</c>.</summary>
+            /// <returns>The state as it was immediately before the bits were set.</returns>
+            private int InterlockedOrState(int value) {
+#if NET5_0_OR_GREATER
+                return Interlocked.Or(ref _state, value);
+#else
+                // no Interlocked.Or on this target: retry a compare-exchange until no other thread interferes
+                int expected = _state;
+                int witnessed;
+                while ((witnessed = Interlocked.CompareExchange(ref _state, expected | value, expected)) != expected) {
+                    expected = witnessed;
+                }
+                return witnessed;
+#endif
+            }
 
             public void close() {
                 // close is idempotent; it must never block
-#if NET5_0_OR_GREATER
-                if ((Interlocked.Or(ref _state, StateBits.Closed) & StateBits.Closed) != StateBits.Closed) {
+                if ((InterlockedOrState(StateBits.Closed) & StateBits.Closed) != StateBits.Closed) {
                     // freshly closed, release the construction time reference
                     Release(exclusive: false);
                 }
-#else
-                int oldState, newState;
-                do {
-                    oldState = _state;
-                    newState = oldState | StateBits.Closed;
-                } while (Interlocked.CompareExchange(ref _state, newState, oldState) != oldState);
-                if ((oldState & StateBits.Closed) != StateBits.Closed) {
-                    // freshly closed, release the construction time reference
-                    Release(exclusive: false);
-                }
-#endif
             }
 
 
@@ -897,7 +910,6 @@ public string readline() {
                 }
             }
 
-
             public void resize(long newsize) {
                 using (new MmapLocker(this, exclusive: true)) {
                     if (_fileAccess is not MemoryMappedFileAccess.ReadWrite and not MemoryMappedFileAccess.ReadWriteExecute) {
@@ -931,13 +943,13 @@ public void resize(long newsize) {
                             int fd = unchecked((int)_handle.DangerousGetHandle());
                             PythonNT.ftruncateUnix(fd, newsize);
 
-    #if NET8_0_OR_GREATER
+#if NET8_0_OR_GREATER
                             _file = MemoryMappedFile.CreateFromFile(_handle, _mapName, newsize, _fileAccess, HandleInheritability.None, leaveOpen: true);
-    #else
+#else
                             _sourceStream?.Dispose();
                             _sourceStream = new FileStream(new SafeFileHandle((IntPtr)fd, ownsHandle: false), FileAccess.ReadWrite);
                             _file = CreateFromFile(_sourceStream, _mapName, newsize, _fileAccess, HandleInheritability.None, leaveOpen: true);
-    #endif
+#endif
                             _view = _file.CreateViewAccessor(_offset, newsize, _fileAccess);
                             return;
                         } catch {
@@ -1168,8 +1180,10 @@ private long Position {
                 }
             }
 
+            private bool IsReadOnly => _fileAccess == MemoryMappedFileAccess.Read || _fileAccess == MemoryMappedFileAccess.ReadExecute;
+
             private void EnsureWritable() {
-                if (_fileAccess is MemoryMappedFileAccess.Read or MemoryMappedFileAccess.ReadExecute) {
+                if (IsReadOnly) {
                     throw PythonOps.TypeError("mmap can't modify a read-only memory map.");
                 }
             }
@@ -1299,8 +1313,77 @@ void IWeakReferenceable.SetFinalizer(WeakRefTracker value) {
             }
 
             #endregion
-        }
 
+            public IPythonBuffer GetBuffer(BufferFlags flags = BufferFlags.Simple) {
+                // a writable export cannot be granted on a read-only mapping
+                bool wantsWrite = (flags & BufferFlags.Writable) == BufferFlags.Writable;
+                if (wantsWrite && IsReadOnly) throw PythonOps.BufferError("Object is not writable.");
+                return new MmapBuffer(this, flags);
+            }
+
+            /// <summary>Buffer export over the mmap's view; holds an MmapLocker on the mmap from construction until Dispose.</summary>
+            private sealed unsafe class MmapBuffer : IPythonBuffer {
+                private readonly MmapDefault _mmap;
+                private readonly MmapLocker _locker;
+                private readonly BufferFlags _flags;
+                private SafeMemoryMappedViewHandle? _handle;  // null once disposed
+                private byte* _pointer = null;                // acquired lazily on the first span/pin request
+
+                public MmapBuffer(MmapDefault mmap, BufferFlags flags) {
+                    _mmap = mmap;
+                    _flags = flags;
+                    _locker = new MmapLocker(mmap);
+                    try {
+                        mmap.InterlockedOrState(StateBits.Exporting);
+                        _handle = _mmap._view.SafeMemoryMappedViewHandle;
+                        ItemCount = _mmap.__len__() is int i ? i : throw new NotImplementedException();
+                    } catch {
+                        // this instance is never returned on failure, so release the locker here or nothing ever will
+                        _locker.Dispose();
+                        throw;
+                    }
+                }
+
+                public object Object => _mmap;
+                public bool IsReadOnly => _mmap.IsReadOnly;
+                public int Offset => 0;
+                public string? Format => _flags.HasFlag(BufferFlags.Format) ? "B" : null;
+                public int ItemCount { get; }
+                public int ItemSize => 1;
+                public int NumOfDims => 1;
+                // no explicit shape, strides or suboffsets are exposed for this export
+                public IReadOnlyList<int>? Shape => null;
+                public IReadOnlyList<int>? Strides => null;
+                public IReadOnlyList<int>? SubOffsets => null;
+
+                public unsafe ReadOnlySpan<byte> AsReadOnlySpan() {
+                    if (_handle is null) throw new ObjectDisposedException(nameof(MmapBuffer));
+                    if (_pointer is null) _handle.AcquirePointer(ref _pointer);
+                    return new ReadOnlySpan<byte>(_pointer, ItemCount);
+                }
+
+                public unsafe Span<byte> AsSpan() {
+                    if (_handle is null) throw new ObjectDisposedException(nameof(MmapBuffer));
+                    if (IsReadOnly) throw new InvalidOperationException("object is not writable");
+                    if (_pointer is null) _handle.AcquirePointer(ref _pointer);
+                    return new Span<byte>(_pointer, ItemCount);
+                }
+
+                public unsafe MemoryHandle Pin() {
+                    if (_handle is null) throw new ObjectDisposedException(nameof(MmapBuffer));
+                    if (_pointer is null) _handle.AcquirePointer(ref _pointer);
+                    return new MemoryHandle(_pointer);
+                }
+
+                public void Dispose() {
+                    // the exchange makes Dispose idempotent: only the first caller sees a non-null handle
+                    var handle = Interlocked.Exchange(ref _handle, null);
+                    if (handle is null) return;
+                    if (_pointer is not null) handle.ReleasePointer();
+                    _locker.Dispose();
+                }
+            }
+        }
 
         #region P/Invoke for allocation granularity
 
diff --git a/src/core/IronPython.Modules/re.cs b/src/core/IronPython.Modules/re.cs
index c3df0eb71..bcc9e8baf 100644
--- a/src/core/IronPython.Modules/re.cs
+++ b/src/core/IronPython.Modules/re.cs
@@ -516,11 +516,6 @@ private string ValidateString(object? @string) {
                         case IList<byte> b:
                             str = b.MakeString();
                             break;
-#if FEATURE_MMAP
-                        case MmapModule.MmapDefault mmapFile:
-                            str = mmapFile.GetSearchString().MakeString();
-                            break;
-#endif
                         case string _:
                         case ExtensibleString _:
                             throw PythonOps.TypeError("cannot use a bytes pattern on a string-like object");
@@ -537,9 +532,6 @@ private string ValidateString(object? @string) {
                             break;
                         case IBufferProtocol _:
                         case IList<byte> _:
-#if FEATURE_MMAP
-                        case MmapModule.MmapDefault _:
-#endif
                             throw PythonOps.TypeError("cannot use a string pattern on a bytes-like object");
                         default:
                             throw PythonOps.TypeError("expected string or bytes-like object");