Skip to content

Commit

Permalink
#335: Private threads hotfix.
Browse files Browse the repository at this point in the history
#!components: grid-bot

ExecuteScript.cs:
~ Take reference to current DiscordShardedClient.
~ Move correspondence calls when errors are called to use LuaError instead of regular follow ups.
~ Add detection of scripts containing code blocks within zero content (causes exception)
~ Rewrite error for scripts that contain unicode to reduce ambiguity.
~ Rename GridJob to ClientJob.
~ Clean up call to PollDeletion
~ Integrate use of GetChannelAsString and GetGuild.

OnSlashCommand, OnSlashCommandExecuted, LoggerFactory, ScriptLogger:
~ Clean up usings.
~ Rewrite GetGuildId to use extension methods.
~ Change around references to channel to use extension methods.
  • Loading branch information
jf-06 committed Sep 17, 2024
1 parent 6258bf0 commit 2e25e54
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 44 deletions.
32 changes: 21 additions & 11 deletions services/grid-bot/lib/commands/Modules/ExecuteScript.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace Grid.Bot.Interactions.Public;
using System.Text.RegularExpressions;

using Discord;
using Discord.WebSocket;
using Discord.Interactions;

using Loretta.CodeAnalysis;
Expand All @@ -25,7 +26,7 @@ namespace Grid.Bot.Interactions.Public;
using Commands;
using Extensions;

using GridJob = Client.Job;
using ClientJob = Client.Job;

/// <summary>
/// Interaction handler for executing Luau code.
Expand All @@ -36,6 +37,7 @@ namespace Grid.Bot.Interactions.Public;
/// <param name="logger">The <see cref="ILogger"/>.</param>
/// <param name="gridSettings">The <see cref="GridSettings"/>.</param>
/// <param name="scriptsSettings">The <see cref="ScriptsSettings"/>.</param>
/// <param name="discordClient">The <see cref="DiscordShardedClient"/>.</param>
/// <param name="luaUtility">The <see cref="ILuaUtility"/>.</param>
/// <param name="floodCheckerRegistry">The <see cref="IFloodCheckerRegistry"/>.</param>
/// <param name="backtraceUtility">The <see cref="IBacktraceUtility"/>.</param>
Expand All @@ -48,6 +50,7 @@ namespace Grid.Bot.Interactions.Public;
/// - <paramref name="logger"/> cannot be null.
/// - <paramref name="gridSettings"/> cannot be null.
/// - <paramref name="scriptsSettings"/> cannot be null.
/// - <paramref name="discordClient"/> cannot be null.
/// - <paramref name="luaUtility"/> cannot be null.
/// - <paramref name="floodCheckerRegistry"/> cannot be null.
/// - <paramref name="backtraceUtility"/> cannot be null.
Expand All @@ -64,6 +67,7 @@ public partial class ExecuteScript(
ILogger logger,
GridSettings gridSettings,
ScriptsSettings scriptsSettings,
DiscordShardedClient discordClient,
ILuaUtility luaUtility,
IFloodCheckerRegistry floodCheckerRegistry,
IBacktraceUtility backtraceUtility,
Expand All @@ -83,6 +87,7 @@ IGridServerFileHelper gridServerFileHelper
private readonly GridSettings _gridSettings = gridSettings ?? throw new ArgumentNullException(nameof(gridSettings));
private readonly ScriptsSettings _scriptsSettings = scriptsSettings ?? throw new ArgumentNullException(nameof(scriptsSettings));

private readonly DiscordShardedClient _discordClient = discordClient ?? throw new ArgumentNullException(nameof(discordClient));
private readonly ILuaUtility _luaUtility = luaUtility ?? throw new ArgumentNullException(nameof(luaUtility));
private readonly IFloodCheckerRegistry _floodCheckerRegistry = floodCheckerRegistry ?? throw new ArgumentNullException(nameof(floodCheckerRegistry));
private readonly IBacktraceUtility _backtraceUtility = backtraceUtility ?? throw new ArgumentNullException(nameof(backtraceUtility));
Expand Down Expand Up @@ -304,12 +309,20 @@ string script
{
if (string.IsNullOrWhiteSpace(script))
{
await FollowupAsync("The script cannot be empty.");
await LuaErrorAsync("The script cannot be empty!");

return;
}

script = GetCodeBlockContents(script);

if (string.IsNullOrEmpty(script))
{
await LuaErrorAsync("There must be content within a code block!");

return;
}

script = EscapeQuotes(script);

var originalScript = script;
Expand All @@ -318,7 +331,7 @@ string script

if (ContainsUnicode(script))
{
await FollowupAsync("The script cannot contain unicode characters as grid-servers cannot support unicode in transit.");
await LuaErrorAsync("Scripts can only contain ASCII characters!");

return;
}
Expand Down Expand Up @@ -352,7 +365,7 @@ string script
#endif


var gridJob = new GridJob() { id = scriptId, expirationInSeconds = _gridSettings.ScriptExecutionJobMaxTimeout.TotalSeconds };
var gridJob = new ClientJob() { id = scriptId, expirationInSeconds = _gridSettings.ScriptExecutionJobMaxTimeout.TotalSeconds };
var job = new Job(Guid.NewGuid().ToString());

var sw = Stopwatch.StartNew();
Expand Down Expand Up @@ -461,9 +474,8 @@ string script
scriptName
);
scriptName.PollDeletion(
10,
ex => _logger.Warning("Failed to delete '{0}' because: {1}", scriptName, ex.Message),
() => _logger.Debug(
onFailure: ex => _logger.Warning("Failed to delete '{0}' because: {1}", scriptName, ex.Message),
onSuccess: () => _logger.Debug(
"Successfully deleted the script '{0}' at path '{1}'!",
scriptId,
scriptName
Expand Down Expand Up @@ -493,10 +505,8 @@ private async Task AlertForSystem(string script, string originalScript, string s
_backtraceUtility.UploadException(ex);

var userInfo = Context.User.ToString();
var guildInfo = Context.Guild?.ToString() ?? "DMs";

/* Temporary until mfdlabs/grid-bot#335 is resolved */
var channelInfo = Context.Channel?.ToString() ?? Context.Interaction.ChannelId?.ToString() ?? "Thread";
var guildInfo = Context.Interaction.GetGuild(_discordClient)?.ToString() ?? "DMs";
var channelInfo = Context.Interaction.GetChannelAsString();

// Script & original script in attachments
var scriptAttachment = new FileAttachment(new MemoryStream(Encoding.ASCII.GetBytes(script)), "script.lua");
Expand Down
19 changes: 6 additions & 13 deletions services/grid-bot/lib/events/Events/OnSlashCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
using System.Threading.Tasks;
using System.Collections.Generic;

using Discord;
using Discord.WebSocket;
using Discord.Interactions;

using Prometheus;

using Utility;
using Discord;
using Extensions;

/// <summary>
/// Event handler for interactions.
Expand Down Expand Up @@ -87,14 +88,8 @@ ILoggerFactory loggerFactory
}
);

private static string GetGuildId(SocketInteraction interaction)
{
/* Always false in private thread channels, please look into discord-net/Discord.Net#2997 and mfdlabs/grid-bot#335 */
if (interaction.Channel is SocketGuildChannel guildChannel)
return guildChannel.Guild.Id.ToString();

return "DM";
}
private string GetGuildId(SocketInteraction interaction)
=> interaction.GetGuild(_client).ToString() ?? "DM";

/// <summary>
/// Invoke the event handler.
Expand Down Expand Up @@ -149,8 +144,7 @@ public async Task Invoke(SocketInteraction interaction)

_totalUsersBypassedMaintenance.WithLabels(
interaction.User.Id.ToString(),
/* Temporary until mfdlabs/grid-bot#335 is resolved */
interaction.Channel?.Id.ToString() ?? interaction.ChannelId?.ToString() ?? "Thread",
interaction.GetChannelAsString(),
GetGuildId(interaction)
).Inc();
}
Expand All @@ -159,8 +153,7 @@ public async Task Invoke(SocketInteraction interaction)
{
_totalBlacklistedUserAttemptedInteractions.WithLabels(
interaction.User.Id.ToString(),
/* Temporary until mfdlabs/grid-bot#335 is resolved */
interaction.Channel?.Id.ToString() ?? interaction.ChannelId?.ToString() ?? "Thread",
interaction.GetChannelAsString(),
GetGuildId(interaction)
).Inc();

Expand Down
20 changes: 14 additions & 6 deletions services/grid-bot/lib/events/Events/OnSlashCommandExecuted.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ namespace Grid.Bot.Events;
using System.Threading.Tasks;

using Discord;
using Discord.WebSocket;
using Discord.Interactions;

using Prometheus;

using Logging;

using Utility;
using Extensions;


/// <summary>
/// Invoked when slash commands are executed.
Expand All @@ -23,15 +26,18 @@ namespace Grid.Bot.Events;
/// <param name="logger">The <see cref="ILogger"/>.</param>
/// <param name="backtraceUtility">The <see cref="BacktraceUtility"/>.</param>
/// <param name="discordRolesSettings">The <see cref="DiscordRolesSettings"/>.</param>
/// <param name="discordClient">The <see cref="DiscordShardedClient"/>.</param>
/// <exception cref="ArgumentNullException">
/// - <paramref name="logger"/> cannot be null.
/// - <paramref name="backtraceUtility"/> cannot be null.
/// - <paramref name="discordRolesSettings"/> cannot be null.
/// - <paramref name="discordClient"/> cannot be null.
/// </exception>
public class OnInteractionExecuted(
ILogger logger,
IBacktraceUtility backtraceUtility,
DiscordRolesSettings discordRolesSettings
DiscordRolesSettings discordRolesSettings,
DiscordShardedClient discordClient
)
{
private const string UnhandledExceptionOccurredFromCommand = "An error occured with the command:";
Expand All @@ -40,6 +46,7 @@ DiscordRolesSettings discordRolesSettings

private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger));
private readonly IBacktraceUtility _backtraceUtility = backtraceUtility ?? throw new ArgumentNullException(nameof(backtraceUtility));
private readonly DiscordShardedClient _discordClient = discordClient ?? throw new ArgumentNullException(nameof(discordClient));

private readonly Counter _totalInteractionsFailed = Metrics.CreateCounter(
"grid_interactions_failed_total",
Expand All @@ -51,9 +58,9 @@ DiscordRolesSettings discordRolesSettings
"interaction_guild_id"
);

private string GetGuildId(IInteractionContext context)
private string GetGuildId(SocketInteraction interaction)
{
return context.Guild?.Id.ToString() ?? "DM";
return interaction.GetGuild(_discordClient)?.Id.ToString() ?? "DM";
}

/// <summary>
Expand All @@ -64,7 +71,8 @@ private string GetGuildId(IInteractionContext context)
/// <param name="result">The <see cref="IResult"/>.</param>
public async Task Invoke(ICommandInfo command, IInteractionContext context, IResult result)
{
var interaction = context.Interaction;
if (context.Interaction is not SocketInteraction interaction)
return;

if (!result.IsSuccess)
{
Expand All @@ -75,8 +83,8 @@ public async Task Invoke(ICommandInfo command, IInteractionContext context, IRes
interaction.Id.ToString(),
interaction.User.Id.ToString(),
/* Temporary until mfdlabs/grid-bot#335 is resolved */
interaction.ChannelId?.ToString() ?? "Thread",
GetGuildId(context)
interaction.GetChannelAsString(),
GetGuildId(interaction)
).Inc();

if (result is not ExecuteResult executeResult)
Expand Down
27 changes: 18 additions & 9 deletions services/grid-bot/lib/utility/Implementation/LoggerFactory.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
namespace Grid.Bot.Utility;

using System;

using Discord.WebSocket;

using Logging;

using Extensions;

/// <summary>
/// Implementation of <see cref="ILoggerFactory"/>.
/// </summary>
/// <param name="discordClient">The <see cref="DiscordShardedClient"/>.</param>
/// <exception cref="ArgumentNullException"><paramref name="discordClient"/> cannot be null.</exception>
/// <seealso cref="ILoggerFactory"/>
/// <seealso cref="ILogger"/>
public class LoggerFactory : ILoggerFactory
public class LoggerFactory(DiscordShardedClient discordClient) : ILoggerFactory
{
private readonly DiscordShardedClient _discordClient = discordClient ?? throw new ArgumentNullException(nameof(discordClient));

/// <inheritdoc cref="ILoggerFactory.CreateLogger(SocketInteraction)"/>
public ILogger CreateLogger(SocketInteraction interaction)
{
Expand All @@ -21,12 +29,14 @@ public ILogger CreateLogger(SocketInteraction interaction)
logToFileSystem: false
);

logger.CustomLogPrefixes.Add(() => interaction.ChannelId?.ToString() ?? "Thread");
logger.CustomLogPrefixes.Add(() => interaction.User.Id.ToString());
logger.CustomLogPrefixes.Add(() => interaction.GetChannelAsString());
logger.CustomLogPrefixes.Add(() => interaction.User.ToString());

var guild = interaction.GetGuild(_discordClient);

// Add guild id if the interaction is from a guild.
if (interaction.Channel is SocketGuildChannel guildChannel)
logger.CustomLogPrefixes.Add(() => guildChannel.Guild.Id.ToString());
if (guild is not null)
logger.CustomLogPrefixes.Add(() => guild.ToString());

return logger;
}
Expand All @@ -41,13 +51,12 @@ public ILogger CreateLogger(SocketMessage message)
logToFileSystem: false
);

logger.CustomLogPrefixes.Add(() => message.Channel.Id.ToString());
logger.CustomLogPrefixes.Add(() => message.Author.Id.ToString());
logger.CustomLogPrefixes.Add(() => message.Channel.ToString());
logger.CustomLogPrefixes.Add(() => message.Author.ToString());

// Add guild id if the message is from a guild.
/* Always false in private thread channels, please look into discord-net/Discord.Net#2997 and mfdlabs/grid-bot#335 */
if (message.Channel is SocketGuildChannel guildChannel)
logger.CustomLogPrefixes.Add(() => guildChannel.Guild.Id.ToString());
logger.CustomLogPrefixes.Add(() => guildChannel.Guild.ToString());

return logger;
}
Expand Down
17 changes: 12 additions & 5 deletions services/grid-bot/lib/utility/Implementation/ScriptLogger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ namespace Grid.Bot.Utility;
using System.Collections.Concurrent;

using Discord;
using Discord.WebSocket;
using Discord.Interactions;

using Newtonsoft.Json;

using Random;
using Networking;

using Extensions;


/// <summary>
/// Handles sending alerts to a Discord webhook.
/// </summary>
Expand All @@ -26,6 +30,7 @@ public class ScriptLogger : IScriptLogger
private readonly IPercentageInvoker _percentageInvoker;
private readonly IHttpClientFactory _httpClientFactory;
private readonly ScriptsSettings _scriptsSettings;
private readonly DiscordShardedClient _discordClient;

private readonly ConcurrentBag<string> _scriptHashes = new();

Expand All @@ -36,23 +41,27 @@ public class ScriptLogger : IScriptLogger
/// <param name="percentageInvoker">The <see cref="IPercentageInvoker"/> to use.</param>
/// <param name="httpClientFactory">The <see cref="IHttpClientFactory"/> to use.</param>
/// <param name="scriptsSettings">The <see cref="ScriptsSettings"/> to use.</param>
/// <param name="discordClient">The <see cref="DiscordShardedClient"/> to use.</param>
/// <exception cref="ArgumentNullException">
/// - <paramref name="localIpAddressProvider"/> cannot be null.
/// - <paramref name="percentageInvoker"/> cannot be null.
/// - <paramref name="httpClientFactory"/> cannot be null.
/// - <paramref name="scriptsSettings"/> cannot be null.
/// - <paramref name="discordClient"/> cannot be null.
/// </exception>
public ScriptLogger(
ILocalIpAddressProvider localIpAddressProvider,
IPercentageInvoker percentageInvoker,
IHttpClientFactory httpClientFactory,
ScriptsSettings scriptsSettings
ScriptsSettings scriptsSettings,
DiscordShardedClient discordClient
)
{
_localIpAddressProvider = localIpAddressProvider ?? throw new ArgumentNullException(nameof(localIpAddressProvider));
_percentageInvoker = percentageInvoker ?? throw new ArgumentNullException(nameof(percentageInvoker));
_httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
_scriptsSettings = scriptsSettings ?? throw new ArgumentNullException(nameof(scriptsSettings));
_discordClient = discordClient ?? throw new ArgumentNullException(nameof(discordClient));

foreach (var hash in _scriptsSettings.LoggedScriptHashes)
_scriptHashes.Add(hash);
Expand Down Expand Up @@ -83,10 +92,8 @@ public async Task LogScriptAsync(string script, ShardedInteractionContext contex
// username based off machine info
var username = $"{Environment.MachineName} ({_localIpAddressProvider.AddressV4} / {_localIpAddressProvider.AddressV6})";
var userInfo = context.User.ToString();
var guildInfo = context.Guild?.ToString() ?? "DMs";

/* Temporary until mfdlabs/grid-bot#335 is resolved */
var channelInfo = context.Channel?.ToString() ?? context.Interaction.ChannelId?.ToString() ?? "Thread";
var guildInfo = context.Interaction.GetGuild(_discordClient)?.ToString() ?? "DMs";
var channelInfo = context.Interaction.GetChannelAsString();

// Get a SHA256 hash of the script (hex)
var scriptHash = string.Join("", SHA256.HashData(Encoding.UTF8.GetBytes(script)).Select(b => b.ToString("x2")));
Expand Down

0 comments on commit 2e25e54

Please sign in to comment.