diff --git a/src/Tyres/DataModels/TyreStint.cs b/src/Tyres/DataModels/TyreStint.cs index 6cc48ee..0af6973 100644 --- a/src/Tyres/DataModels/TyreStint.cs +++ b/src/Tyres/DataModels/TyreStint.cs @@ -24,13 +24,8 @@ public class TyreStint public float TrackTemperature; [LoadColumn(10)] public string Reason; - [LoadColumn(11)] - public float Laps; - } - - public class DistanceTyreStint : TyreStint - { - public float Distance { get; set; } + [LoadColumn(12)] + public float Distance; } public class TyreStintPrediction @@ -38,26 +33,4 @@ public class TyreStintPrediction [ColumnName("Score")] public float Distance; } - - public class LapsToDistanceInput - { - public float TrackLength { get; set; } - public float Laps { get; set; } - } - - public class LapsToDistanceOutput - { - public float Distance { get; set; } - } - - [CustomMappingFactoryAttribute(nameof(CustomMappings.DistanceMapping))] - public class CustomMappings : CustomMappingFactory - { - public static void DistanceMapping(LapsToDistanceInput input, LapsToDistanceOutput output) => output.Distance = input.Laps * input.TrackLength; - - public override Action GetMapping() - { - return DistanceMapping; - } - } } diff --git a/src/Tyres/DomainModels/Track.cs b/src/Tyres/DomainModels/Track.cs index 3014ffd..b4bda1c 100644 --- a/src/Tyres/DomainModels/Track.cs +++ b/src/Tyres/DomainModels/Track.cs @@ -5,7 +5,7 @@ namespace Tyres.DomainModels public class Track { public string Name { get; set; } - public float Distance { get; set; } + public float TrackLength { get; set; } } public class Race diff --git a/src/Tyres/Program.cs b/src/Tyres/Program.cs index 8082f59..331b903 100644 --- a/src/Tyres/Program.cs +++ b/src/Tyres/Program.cs @@ -18,7 +18,6 @@ public class Program static void Main(string[] args) { var mlContext = new MLContext(seed: 0); - mlContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly); // Load data var data = mlContext.Data.LoadFromTextFile(DatasetsLocation, ';', true); @@ -27,11 +26,6 @@ static void Main(string[] args) var filtered = mlContext.Data.FilterByCustomPredicate(data, (TyreStint row) => !(row.Reason.Equals("Pit Stop") || row.Reason.Equals("Race Finish")) ); var debug = mlContext.Data.CreateEnumerable(filtered, reuseRowObject: false).Count(); - // Transforming to add distancefeature - var customTransformer = mlContext.Transforms.CustomMapping(CustomMappings.DistanceMapping, "LapsToDistance"); - var transformed = customTransformer.Fit(filtered).Transform(filtered); - var debug2 = mlContext.Data.CreateEnumerable(transformed, reuseRowObject: false).ToList(); - // Divide dataset into training and testing data var split = mlContext.Data.TrainTestSplit(filtered, testFraction: 0.1); var trainingData = split.TrainSet; @@ -39,7 +33,7 @@ static void Main(string[] args) // Run AutoML experiment - var experimentTime = 60u; + var experimentTime = 900u; Console.WriteLine("=============== Training the model ==============="); Console.WriteLine($"Running AutoML regression experiment for {experimentTime} seconds..."); var experimentResult = mlContext.Auto() @@ -47,16 +41,14 @@ static void Main(string[] args) .Execute(trainingData, testingData, columnInformation: new ColumnInformation() { - CategoricalColumnNames = { nameof(DistanceTyreStint.Team), nameof(DistanceTyreStint.Car), nameof(DistanceTyreStint.Driver), nameof(DistanceTyreStint.Compound), nameof(DistanceTyreStint.Reason) }, - NumericColumnNames = { nameof(DistanceTyreStint.AirTemperature), nameof(DistanceTyreStint.TrackTemperature) }, - LabelColumnName = nameof(DistanceTyreStint.Distance) - }, - preFeaturizer: customTransformer - ); + CategoricalColumnNames = { nameof(TyreStint.Team), nameof(TyreStint.Car), nameof(TyreStint.Driver), nameof(TyreStint.Compound), nameof(TyreStint.Reason) }, + NumericColumnNames = { nameof(TyreStint.AirTemperature), nameof(TyreStint.TrackTemperature) }, + LabelColumnName = nameof(TyreStint.Distance) + } + ); // Print top models found by AutoML - Console.WriteLine(); TrainingHelper.PrintTopModels(experimentResult); Console.WriteLine("===== Evaluating model's accuracy with test data ====="); @@ -65,18 +57,18 @@ static void Main(string[] args) var trainedModel = best.Model; var predictions = trainedModel.Transform(testingData); - var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: nameof(DistanceTyreStint.Distance), scoreColumnName: "Score"); + var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: nameof(TyreStint.Distance), scoreColumnName: "Score"); TrainingHelper.PrintRegressionMetrics(best.TrainerName, metrics); // Run sample predictions var predictionEngine = mlContext.Model.CreatePredictionEngine(trainedModel); - var lh = new DistanceTyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Mercedes", Car = "W12", Driver = "Lewis Hamilton", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; + var lh = new TyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Mercedes", Car = "W12", Driver = "Lewis Hamilton", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; var lhPred = predictionEngine.Predict(lh); var lhLaps = lhPred.Distance / 5412f; - var mv = new DistanceTyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Red Bull", Car = "RB16B", Driver = "Max Verstappen", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; + var mv = new TyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Red Bull", Car = "RB16B", Driver = "Max Verstappen", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; var mvPred = predictionEngine.Predict(mv); var mvLaps = mvPred.Distance / 5412f; @@ -97,7 +89,7 @@ static void Main(string[] args) new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C5"}, }; - DataHelper.PrintPredictionTable(predictionEngine, "Monaco", 25.0f, 43.7f, top10Monaco); + DataHelper.PrintPredictionTable(predictionEngine, "Monaco", 20.0f, 35.0f, top10Monaco); } } } diff --git a/src/Tyres/StaticData/Season2021.cs b/src/Tyres/StaticData/Season2021.cs index 8810140..3c835c5 100644 --- a/src/Tyres/StaticData/Season2021.cs +++ b/src/Tyres/StaticData/Season2021.cs @@ -33,11 +33,11 @@ public static class Season2021 public static Dictionary Tracks = new Dictionary() { - { "Bahrain", new Track() {Name = "Bahrain International Circuit", Distance = 5412f} }, - { "Imola", new Track() {Name = "Imola", Distance = 4909f } }, - { "Portimão", new Track() {Name = "Portimão", Distance = 4653f } }, - { "Catalunya", new Track() {Name = "Catalunya", Distance = 4675f } }, - { "Monaco", new Track() {Name = "Monaco", Distance = 3337f } } + { "Bahrain", new Track() {Name = "Bahrain International Circuit", TrackLength = 5412f} }, + { "Imola", new Track() {Name = "Imola", TrackLength = 4909f } }, + { "Portimão", new Track() {Name = "Portimão", TrackLength = 4653f } }, + { "Catalunya", new Track() {Name = "Catalunya", TrackLength = 4675f } }, + { "Monaco", new Track() {Name = "Monaco", TrackLength = 3337f } } }; } } diff --git a/src/Tyres/ViewHelpers/DataHelper.cs b/src/Tyres/ViewHelpers/DataHelper.cs index 9ac5415..1b4fce3 100644 --- a/src/Tyres/ViewHelpers/DataHelper.cs +++ b/src/Tyres/ViewHelpers/DataHelper.cs @@ -11,14 +11,16 @@ public class DataHelper { public static void PrintPredictionTable(PredictionEngine predictionEngine, string trackName, float airTemperature, float trackTemperature, List drivers) { - Console.WriteLine(Season2021.Tracks[trackName].Name); + var monaco = Season2021.Tracks[trackName]; + + Console.WriteLine(monaco.Name); Console.WriteLine("=========="); foreach (var d in drivers) { var prediction = predictionEngine.Predict(new TyreStint() { - Track = Season2021.Tracks[trackName].Name, - TrackLength = Season2021.Tracks[trackName].Distance, + Track = monaco.Name, + TrackLength = monaco.TrackLength, Team = d.Team, Car = d.Car, Driver = d.Name, @@ -27,7 +29,7 @@ public static void PrintPredictionTable(PredictionEngine experimentResult) { + Console.WriteLine(); // Get top few runs ranked by R-Squared. // R-Squared is a metric to maximize, so OrderByDescending() is correct. // For RMSE and other regression metrics, OrderByAscending() is correct.