Skip to content

Commit

Permalink
Move clean up logic into PackageDownloader
Browse files Browse the repository at this point in the history
  • Loading branch information
joelverhagen committed Dec 21, 2023
1 parent 0fd4126 commit 1f95872
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 102 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
$notMatching = $report | ? { !$_.StartsWith("Matching") }
Write-Output "versions-changed=$(($notMatching.Length -ne 0).ToString().ToLowerInvariant())" >> $env:GITHUB_OUTPUT
format:
check-formatting:
runs-on: windows-latest
defaults:
run:
Expand Down
16 changes: 1 addition & 15 deletions Invoke-DownloadPackages.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,7 @@ if ($Force -and (Test-Path $PackagesDir)) {
$toolDir = Join-Path $PSScriptRoot "src/Knapcode.NuGetTools.PackageDownloader"
& dotnet run --project $toolDir --configuration Release -- download $PackagesDir
if ($LASTEXITCODE -ne 0) {
throw "Package downloader failed with exit code $LastExitCode."
throw "Package downloader failed with exit code $LASTEXITCODE."
}

Write-Host "Successfully downloaded NuGet packages"

function Remove-ExtraFiles($pattern) {
Write-Host "Deleting $pattern files"
Get-ChildItem (Join-Path $PackagesDir $pattern) -Recurse | Remove-Item
}

# leave DLLs for loading at runtime and .sha512 for existence check
Remove-ExtraFiles "*.nupkg"
Remove-ExtraFiles "*.nuspec"
Remove-ExtraFiles "*.xml"
Remove-ExtraFiles "*.png"
Remove-ExtraFiles "*.md"
Remove-ExtraFiles ".signature.p7s"
Remove-ExtraFiles ".nupkg.metadata"
39 changes: 23 additions & 16 deletions src/Knapcode.NuGetTools.Logic.Direct/PackageRangeDownloader.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Collections.Concurrent;
using NuGet.Common;
using NuGet.PackageManagement;
using NuGet.Packaging;
using NuGet.Packaging.Core;
using NuGet.Packaging.Signing;
using NuGet.Protocol;
Expand Down Expand Up @@ -127,31 +128,37 @@ private async Task DownloadPackageAsync(
ILogger log,
CancellationToken token)
{
var resolver = new VersionFolderPathResolver(_nuGetSettings.GlobalPackagesFolder);
var hashPath = resolver.GetHashPath(identity.Id, identity.Version);
if (File.Exists(hashPath))
{
log.LogInformation($"The package '{identity}' is already available.");
return;
}

var packageDownloadContext = new PackageDownloadContext(sourceCacheContext);
var result = await PackageDownloader.GetDownloadResourceResultAsync(

using var downloadResult = await PackageDownloader.GetDownloadResourceResultAsync(
sourceRepositories,
packageIdentity: identity,
downloadContext: packageDownloadContext,
globalPackagesFolder: _nuGetSettings.GlobalPackagesFolder,
logger: log,
token: token);

using (result)
if (downloadResult.Status != DownloadResourceResultStatus.Available)
{
if (result.Status != DownloadResourceResultStatus.Available)
{
throw new InvalidOperationException($"The package '{identity}' is not available.");
}

await GlobalPackagesFolderUtility.AddPackageAsync(
result.PackageSource,
identity,
result.PackageStream,
_nuGetSettings.GlobalPackagesFolder,
packageDownloadContext.ParentId,
ClientPolicyContext.GetClientPolicy(_nuGetSettings.Settings, log),
log,
token);
throw new InvalidOperationException($"The package '{identity}' is not available.");
}

using var addResult = await GlobalPackagesFolderUtility.AddPackageAsync(
downloadResult.PackageSource,
identity,
downloadResult.PackageStream,
_nuGetSettings.GlobalPackagesFolder,
packageDownloadContext.ParentId,
ClientPolicyContext.GetClientPolicy(_nuGetSettings.Settings, log),
log,
token);
}
}
110 changes: 60 additions & 50 deletions src/Knapcode.NuGetTools.Logic.Direct/VersionedToolsFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ namespace Knapcode.NuGetTools.Logic.Direct;

public class VersionedToolsFactory : IToolsFactory
{
private static readonly NuGetFramework Net48 = NuGetFramework.Parse("net48");
private static readonly NuGetFramework Net6 = NuGetFramework.Parse("net6.0");
public static readonly IReadOnlyList<NuGetFramework> CompatibleFrameworks = new[]
{
NuGetFramework.Parse("net48"),
NuGetFramework.Parse("net8.0"),
};

private static readonly Lazy<ModuleDefinition> NuGetLogic2xModule
= new Lazy<ModuleDefinition>(() => ModuleDefinition.ReadModule(typeof(NuGetLogic2x).Assembly.Location));
Expand All @@ -32,10 +35,10 @@ private static readonly Lazy<ModuleDefinition> NuGetLogic3xModule
private readonly IAlignedVersionsDownloader _downloader;
private readonly IFrameworkList _frameworkList;
private readonly NuGetSettings _settings;
private readonly Lazy<Task<Dictionary<string, NuGetVersion>>> _versions;
private readonly Lazy<Task<Dictionary<string, NuGetVersion>>> _stringToVersion;
private readonly Lazy<Task<List<string>>> _versionStrings;
private readonly Lazy<Task<string>> _latestVersion;
private readonly Lazy<Task<Dictionary<NuGetVersion, NuGetRelease>>> _releases;
private readonly Lazy<Task<IReadOnlySet<NuGetVersion>>> _versionSet;

private readonly ConcurrentDictionary<NuGetVersion, Task<IToolsService>> _toolServices = new();
private readonly ConcurrentDictionary<NuGetVersion, Task<IFrameworkPrecedenceService>> _frameworkPrecedenceServices = new();
Expand All @@ -52,7 +55,7 @@ public VersionedToolsFactory(
_settings = settings;
_nuGetLog = new MicrosoftLogger(log);

_releases = new Lazy<Task<Dictionary<NuGetVersion, NuGetRelease>>>(async () =>
_versionSet = new Lazy<Task<IReadOnlySet<NuGetVersion>>>(async () =>
{
using (var sourceCacheContext = new SourceCacheContext())
{
Expand All @@ -61,37 +64,31 @@ public VersionedToolsFactory(
sourceCacheContext,
_nuGetLog,
CancellationToken.None);
var pairs2x = versions2x
.Select(x => new KeyValuePair<NuGetVersion, NuGetRelease>(x, NuGetRelease.Version2x));

var versions3x = await _downloader.GetDownloadedVersionsAsync(
Constants.PackageIds3x,
sourceCacheContext,
_nuGetLog,
CancellationToken.None);
var pairs3x = versions3x
.Select(x => new KeyValuePair<NuGetVersion, NuGetRelease>(x, NuGetRelease.Version3x));

return pairs2x
.Concat(pairs3x)
.ToDictionary(x => x.Key, x => x.Value);
return versions2x.Concat(versions3x).ToHashSet();
}
});

_versions = new Lazy<Task<Dictionary<string, NuGetVersion>>>(async () =>
_stringToVersion = new Lazy<Task<Dictionary<string, NuGetVersion>>>(async () =>
{
var releases = await _releases.Value;
var releases = await _versionSet.Value;

return releases
.ToDictionary(
x => x.Key.ToNormalizedString(),
x => x.Key,
x => x.ToNormalizedString(),
x => x,
StringComparer.OrdinalIgnoreCase);
});

_versionStrings = new Lazy<Task<List<string>>>(async () =>
{
var versions = await _versions.Value;
var versions = await _stringToVersion.Value;

return versions
.OrderByDescending(x => x.Value)
Expand All @@ -101,7 +98,7 @@ public VersionedToolsFactory(

_latestVersion = new Lazy<Task<string>>(async () =>
{
var versions = await _versions.Value;
var versions = await _stringToVersion.Value;

return versions
.OrderByDescending(x => x.Value)
Expand Down Expand Up @@ -130,7 +127,7 @@ public async Task<IEnumerable<string>> GetAvailableVersionsAsync(CancellationTok
matchingVersion,
async key =>
{
var logic = await GetLogicAsync(key);
var logic = await InitializeAndGetLogicAsync(key);
return new ToolsService(version, logic);
});
}
Expand All @@ -156,7 +153,7 @@ public async Task<IEnumerable<string>> GetAvailableVersionsAsync(CancellationTok
matchingVersion,
async key =>
{
var logic = await GetLogicAsync(key);
var logic = await InitializeAndGetLogicAsync(key);
return new FrameworkPrecedenceService(
version,
_frameworkList,
Expand All @@ -182,7 +179,7 @@ public Task<string> GetLatestVersionAsync(CancellationToken token)

private async Task<NuGetVersion?> GetMatchingVersionAsync(string version)
{
var versions = await _versions.Value;
var versions = await _stringToVersion.Value;
NuGetVersion? matchedVersion;
if (!versions.TryGetValue(version, out matchedVersion))
{
Expand All @@ -192,48 +189,58 @@ public Task<string> GetLatestVersionAsync(CancellationToken token)
return matchedVersion;
}

private async Task<INuGetLogic> GetLogicAsync(NuGetVersion version)
private async Task<INuGetLogic> InitializeAndGetLogicAsync(NuGetVersion version)
{
var context = Contexts.GetOrAdd(version, _ => new VersionContext());
await _versionSet.Value;

var context = await GetContextAsync(_settings.GlobalPackagesFolder, version);

return context.Logic!;
}

public static async Task<IEnumerable<Assembly>> GetLoadedAssembliesAsync(string packagesFolder, NuGetVersion version)
{
var context = await GetContextAsync(packagesFolder, version);

return context.AssemblyLoadContext!.Assemblies;
}

private static async Task<VersionContext> GetContextAsync(string packagesFolder, NuGetVersion version)
{
var context = Contexts.GetOrAdd(version, _ => new VersionContext());
if (context.Logic is not null)
{
return context.Logic;
return context;
}

await context.Lock.WaitAsync();
try
{
if (context.Logic is not null)
{
return context.Logic;
return context;
}

return await InitializeContextAsync(version, context);
return InitializeContext(packagesFolder, version, context);
}
finally
{
context.Lock.Release();
}
}

private async Task<INuGetLogic> InitializeContextAsync(NuGetVersion version, VersionContext context)
private static VersionContext InitializeContext(string packagesFolder, NuGetVersion version, VersionContext context)
{
var releases = await _releases.Value;
NuGetRelease release;
if (!releases.TryGetValue(version, out release))
{
throw new ArgumentException($"The provided version '{version}' is not supported");
}
var release = version.Major >= 3 ? NuGetRelease.Version3x : NuGetRelease.Version2x;

var assemblyLoadContext = new AssemblyLoadContext(
name: $"NuGet {version.ToNormalizedString()}",
isCollectible: false);

var logicAssembly = release switch
{
NuGetRelease.Version2x => GetV2Implementation(version, assemblyLoadContext),
NuGetRelease.Version3x => GetV3Implementation(version, assemblyLoadContext),
NuGetRelease.Version2x => GetV2Implementation(packagesFolder, version, assemblyLoadContext),
NuGetRelease.Version3x => GetV3Implementation(packagesFolder, version, assemblyLoadContext),
_ => throw new NotImplementedException(),
};

Expand All @@ -244,31 +251,31 @@ private async Task<INuGetLogic> InitializeContextAsync(NuGetVersion version, Ver
context.Logic = (INuGetLogic)Activator.CreateInstance(logicType)!;
context.AssemblyLoadContext = assemblyLoadContext;

return context.Logic;
return context;
}

private Assembly GetV2Implementation(NuGetVersion version, AssemblyLoadContext context)
private static Assembly GetV2Implementation(string packagesFolder, NuGetVersion version, AssemblyLoadContext context)
{
var coreIdentity = new PackageIdentity(Constants.CoreId, version);
var assemblies = LoadPackageAssemblies(context, coreIdentity);
var assemblies = LoadPackageAssemblies(packagesFolder, coreIdentity, context);

var logicAssembly = RewriteProxyReferences(context, NuGetLogic2xModule, assemblies);
return logicAssembly;
}

private Assembly GetV3Implementation(NuGetVersion version, AssemblyLoadContext context)
private static Assembly GetV3Implementation(string packagesFolder, NuGetVersion version, AssemblyLoadContext context)
{
var versioningIdentity = new PackageIdentity(Constants.VersioningId, version);
var assemblies = LoadPackageAssemblies(context, versioningIdentity);
var assemblies = LoadPackageAssemblies(packagesFolder, versioningIdentity, context);

var frameworksIdentity = new PackageIdentity(Constants.FrameworksId, version);
assemblies.AddRange(LoadPackageAssemblies(context, frameworksIdentity));
assemblies.AddRange(LoadPackageAssemblies(packagesFolder, frameworksIdentity, context));

var logicAssembly = RewriteProxyReferences(context, NuGetLogic3xModule, assemblies);
return logicAssembly;
}

private Assembly RewriteProxyReferences(
private static Assembly RewriteProxyReferences(
AssemblyLoadContext context,
Lazy<ModuleDefinition> lazyBaseModule,
List<Assembly> newReferences)
Expand Down Expand Up @@ -297,11 +304,9 @@ private Assembly RewriteProxyReferences(
return context.LoadFromStream(moduleStream);
}

private List<Assembly> LoadPackageAssemblies(
AssemblyLoadContext context,
PackageIdentity packageIdentity)
private static List<Assembly> LoadPackageAssemblies(string packagesFolder, PackageIdentity packageIdentity, AssemblyLoadContext context)
{
var pathResolver = new VersionFolderPathResolver(_settings.GlobalPackagesFolder);
var pathResolver = new VersionFolderPathResolver(packagesFolder);
var hashPath = pathResolver.GetHashPath(packageIdentity.Id, packageIdentity.Version);

if (!File.Exists(hashPath))
Expand All @@ -313,13 +318,18 @@ private List<Assembly> LoadPackageAssemblies(

using (var packageReader = new PackageFolderReader(installPath))
{
if (!TryLoadWithFramework(context, Net6, installPath, packageReader, out var assemblies)
&& !TryLoadWithFramework(context, Net48, installPath, packageReader, out assemblies))
foreach (var framework in CompatibleFrameworks)
{
throw new InvalidOperationException($"The package '{packageIdentity}' is not compatible with net6.0 or net48.");
if (TryLoadWithFramework(context, framework, installPath, packageReader, out var assemblies))
{
return assemblies;
}
}

return assemblies;
var frameworks = string.Join(", ", CompatibleFrameworks.Select(x => x.GetShortFolderName()));
throw new InvalidOperationException(
$"The package {packageIdentity.Id} {packageIdentity.Version} is not compatible " +
$"with any of the following frameworks: {frameworks}");
}
}

Expand Down
Loading

0 comments on commit 1f95872

Please sign in to comment.