-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix raceconditions for SimpleWebsocketCollection
- Loading branch information
Showing
4 changed files
with
36 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,77 +1,55 @@ | ||
using System.Collections.Concurrent; | ||
using OpenShock.Common.Extensions; | ||
|
||
namespace OpenShock.Common.Websocket; | ||
|
||
public sealed class SimpleWebsocketCollection<T, TR> where T : class, IWebsocketController<TR> | ||
{ | ||
private readonly ConcurrentDictionary<Guid, List<T>> _websockets = new(); | ||
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1); | ||
Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs GitHub Actions / Analyze (csharp)
Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs GitHub Actions / build
Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs GitHub Actions / build
Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs GitHub Actions / build
|
||
private readonly Dictionary<Guid, List<T>> _websockets = []; | ||
|
||
public void RegisterConnection(T controller) | ||
public async Task RegisterConnection(T controller) | ||
{ | ||
var list = _websockets.GetOrAdd(controller.Id, | ||
new List<T> { controller }); | ||
lock (list) | ||
using (await _semaphore.LockAsyncScoped()) | ||
{ | ||
if (!list.Contains(controller)) list.Add(controller); | ||
if (!_websockets.TryGetValue(controller.Id, out var list)) | ||
{ | ||
list = [controller]; | ||
_websockets.Add(controller.Id, list); | ||
} | ||
|
||
list.Add(controller); | ||
} | ||
} | ||
|
||
public void UnregisterConnection(T controller) | ||
public async Task<bool> UnregisterConnection(T controller) | ||
{ | ||
var key = controller.Id; | ||
if (!_websockets.TryGetValue(key, out var list)) return; | ||
|
||
lock (list) | ||
using (await _semaphore.LockAsyncScoped()) | ||
{ | ||
list.Remove(controller); | ||
if (list.Count <= 0) _websockets.TryRemove(key, out _); | ||
if (!_websockets.TryGetValue(controller.Id, out var list)) return false; | ||
if (!list.Remove(controller)) return false; | ||
if (list.Count == 0) | ||
{ | ||
_websockets.Remove(controller.Id); | ||
} | ||
} | ||
} | ||
|
||
public bool IsConnected(Guid id) => _websockets.ContainsKey(id); | ||
|
||
public IList<T> GetConnections(Guid id) | ||
{ | ||
if (_websockets.TryGetValue(id, out var list)) | ||
return list; | ||
return Array.Empty<T>(); | ||
return true; | ||
} | ||
|
||
public async ValueTask SendMessageTo(Guid id, TR msg) | ||
public async Task<T[]> GetConnections(Guid id) | ||
{ | ||
var list = GetConnections(id); | ||
|
||
// ReSharper disable once ForCanBeConvertedToForeach | ||
for (var i = 0; i < list.Count; i++) | ||
using (await _semaphore.LockAsyncScoped()) | ||
{ | ||
var conn = list[i]; | ||
await conn.QueueMessage(msg); | ||
if (!_websockets.TryGetValue(id, out var list)) return []; | ||
return list.ToArray(); | ||
} | ||
} | ||
|
||
public Task SendMessageTo(TR msg, params Guid[] id) => SendMessageTo(id, msg); | ||
|
||
public Task SendMessageTo(IEnumerable<Guid> id, TR msg) | ||
{ | ||
var tasks = id.Select(x => SendMessageTo(x, msg).AsTask()); | ||
return Task.WhenAll(tasks); | ||
} | ||
|
||
public async ValueTask SendMessageToAll(TR msg) | ||
{ | ||
// Im cloning a moment-in-time snapshot on purpose here, so we dont miss any connections. | ||
// This is fine since this is not regularly called, and does not need to be realtime. | ||
foreach (var (_, list) in _websockets.ToArray()) | ||
foreach (var websocketController in list) | ||
await websocketController.QueueMessage(msg); | ||
} | ||
|
||
public IEnumerable<T> GetConnectedById(IEnumerable<Guid> ids) | ||
public async Task<uint> GetCount() | ||
{ | ||
var found = new List<T>(); | ||
foreach (var id in ids) found.AddRange(GetConnections(id)); | ||
return found; | ||
using (await _semaphore.LockAsyncScoped()) | ||
{ | ||
return (uint)_websockets.Sum(x => x.Value.Count); | ||
} | ||
} | ||
|
||
public uint Count => (uint)_websockets.Sum(x => x.Value.Count); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters