Skip to content

Commit

Permalink
Fix raceconditions for SimpleWebsocketCollection
Browse files Browse the repository at this point in the history
  • Loading branch information
hhvrc committed Feb 3, 2025
1 parent 8179e75 commit d525549
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 60 deletions.
82 changes: 30 additions & 52 deletions Common/Websocket/SimpleWebsocketCollection.cs
Original file line number Diff line number Diff line change
@@ -1,77 +1,55 @@
using System.Collections.Concurrent;
using OpenShock.Common.Extensions;

namespace OpenShock.Common.Websocket;

public sealed class SimpleWebsocketCollection<T, TR> where T : class, IWebsocketController<TR>
{
private readonly ConcurrentDictionary<Guid, List<T>> _websockets = new();
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1);

Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs

View workflow job for this annotation

GitHub Actions / build

Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs

View workflow job for this annotation

GitHub Actions / build

Check warning on line 7 in Common/Websocket/SimpleWebsocketCollection.cs

View workflow job for this annotation

GitHub Actions / build

private readonly Dictionary<Guid, List<T>> _websockets = [];

public void RegisterConnection(T controller)
public async Task RegisterConnection(T controller)
{
var list = _websockets.GetOrAdd(controller.Id,
new List<T> { controller });
lock (list)
using (await _semaphore.LockAsyncScoped())
{
if (!list.Contains(controller)) list.Add(controller);
if (!_websockets.TryGetValue(controller.Id, out var list))
{
list = [controller];
_websockets.Add(controller.Id, list);
}

list.Add(controller);
}
}

public void UnregisterConnection(T controller)
public async Task<bool> UnregisterConnection(T controller)
{
var key = controller.Id;
if (!_websockets.TryGetValue(key, out var list)) return;

lock (list)
using (await _semaphore.LockAsyncScoped())
{
list.Remove(controller);
if (list.Count <= 0) _websockets.TryRemove(key, out _);
if (!_websockets.TryGetValue(controller.Id, out var list)) return false;
if (!list.Remove(controller)) return false;
if (list.Count == 0)
{
_websockets.Remove(controller.Id);
}
}
}

public bool IsConnected(Guid id) => _websockets.ContainsKey(id);

public IList<T> GetConnections(Guid id)
{
if (_websockets.TryGetValue(id, out var list))
return list;
return Array.Empty<T>();
return true;
}

public async ValueTask SendMessageTo(Guid id, TR msg)
public async Task<T[]> GetConnections(Guid id)
{
var list = GetConnections(id);

// ReSharper disable once ForCanBeConvertedToForeach
for (var i = 0; i < list.Count; i++)
using (await _semaphore.LockAsyncScoped())
{
var conn = list[i];
await conn.QueueMessage(msg);
if (!_websockets.TryGetValue(id, out var list)) return [];
return list.ToArray();
}
}

public Task SendMessageTo(TR msg, params Guid[] id) => SendMessageTo(id, msg);

public Task SendMessageTo(IEnumerable<Guid> id, TR msg)
{
var tasks = id.Select(x => SendMessageTo(x, msg).AsTask());
return Task.WhenAll(tasks);
}

public async ValueTask SendMessageToAll(TR msg)
{
// Im cloning a moment-in-time snapshot on purpose here, so we dont miss any connections.
// This is fine since this is not regularly called, and does not need to be realtime.
foreach (var (_, list) in _websockets.ToArray())
foreach (var websocketController in list)
await websocketController.QueueMessage(msg);
}

public IEnumerable<T> GetConnectedById(IEnumerable<Guid> ids)
public async Task<uint> GetCount()
{
var found = new List<T>();
foreach (var id in ids) found.AddRange(GetConnections(id));
return found;
using (await _semaphore.LockAsyncScoped())
{
return (uint)_websockets.Sum(x => x.Value.Count);
}
}

public uint Count => (uint)_websockets.Sum(x => x.Value.Count);
}
6 changes: 3 additions & 3 deletions LiveControlGateway/Controllers/LiveControlController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ public LiveControlController(
}

/// <inheritdoc />
protected override Task<bool> TryRegisterConnection()
protected override async Task<bool> TryRegisterConnection()
{
WebsocketManager.LiveControlUsers.RegisterConnection(this);
return Task.FromResult(true);
await WebsocketManager.LiveControlUsers.RegisterConnection(this);
return true;
}

/// <inheritdoc />
Expand Down
3 changes: 1 addition & 2 deletions LiveControlGateway/LifetimeManager/HubLifetime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ public async Task UpdateDevice()
await using var db = await _dbContextFactory.CreateDbContextAsync(_cancellationSource.Token);
await UpdateShockers(db, _cancellationSource.Token);

foreach (var websocketController in
WebsocketManager.LiveControlUsers.GetConnections(HubController.Id))
foreach (var websocketController in await WebsocketManager.LiveControlUsers.GetConnections(HubController.Id))
await websocketController.UpdatePermissions(db);
}

Expand Down
5 changes: 2 additions & 3 deletions LiveControlGateway/LifetimeManager/HubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ public async Task<bool> TryAddDeviceConnection(byte tps, IHubController hubContr
return false;
}

foreach (var websocketController in WebsocketManager.LiveControlUsers.GetConnections(hubLifetime
.HubController.Id))
foreach (var websocketController in await WebsocketManager.LiveControlUsers.GetConnections(hubLifetime.HubController.Id))
await websocketController.UpdateConnectedState(true);
}

Expand Down Expand Up @@ -155,7 +154,7 @@ public async Task RemoveDeviceConnection(IHubController hubController, bool noti

if (notifyLiveControlClients)
{
foreach (var websocketController in WebsocketManager.LiveControlUsers.GetConnections(hubController.Id))
foreach (var websocketController in await WebsocketManager.LiveControlUsers.GetConnections(hubController.Id))
await websocketController.UpdateConnectedState(false);
}

Expand Down

0 comments on commit d525549

Please sign in to comment.