From cfce520ca1fb3087287e76cf0099c1e323df7740 Mon Sep 17 00:00:00 2001 From: Michal Lusiak Date: Sun, 23 May 2021 02:09:28 +0200 Subject: [PATCH] Reorganize the code, prepare for Monaco GP --- src/Tyres/DomainModels/Driver.cs | 5 + src/Tyres/Program.cs | 229 +++--------------------- src/Tyres/StaticData/Season2021.cs | 1 + src/Tyres/ViewHelpers/DataHelper.cs | 35 ++++ src/Tyres/ViewHelpers/TrainingHelper.cs | 56 ++++++ 5 files changed, 119 insertions(+), 207 deletions(-) create mode 100644 src/Tyres/ViewHelpers/DataHelper.cs create mode 100644 src/Tyres/ViewHelpers/TrainingHelper.cs diff --git a/src/Tyres/DomainModels/Driver.cs b/src/Tyres/DomainModels/Driver.cs index 71489f6..baa3364 100644 --- a/src/Tyres/DomainModels/Driver.cs +++ b/src/Tyres/DomainModels/Driver.cs @@ -7,4 +7,9 @@ public class Driver public string Car { get; set; } } + + public class Top10Driver : Driver + { + public string StartingCompound { get; set; } + } } \ No newline at end of file diff --git a/src/Tyres/Program.cs b/src/Tyres/Program.cs index 6d9dfc3..25e902d 100644 --- a/src/Tyres/Program.cs +++ b/src/Tyres/Program.cs @@ -7,12 +7,13 @@ using Tyres.DataModels; using Tyres.DomainModels; using Tyres.StaticData; +using Tyres.ViewHelpers; namespace Tyres { public class Program { - private static string DatasetsLocation = @"../../../../../data/Tyres.csv"; + private static string DatasetsLocation = @"../../../../../data/TyreStints.csv"; static void Main(string[] args) { @@ -62,26 +63,25 @@ static void Main(string[] args) */ - uint experimentTime = 120; + var experimentTime = 600u; Console.WriteLine("=============== Training the model ==============="); Console.WriteLine($"Running AutoML regression experiment for {experimentTime} seconds..."); - ExperimentResult experimentResult = mlContext.Auto() + var experimentResult = mlContext.Auto() .CreateRegressionExperiment(experimentTime) .Execute(trainingData, progressHandler: null, labelColumnName: "Laps"); // Print top models found by AutoML Console.WriteLine(); - PrintTopModels(experimentResult); + TrainingHelper.PrintTopModels(experimentResult); Console.WriteLine("===== Evaluating model's accuracy with test data ====="); - RunDetail best = experimentResult.BestRun; + var best = experimentResult.BestRun; - ITransformer trainedModel = best.Model; - IDataView predictions = trainedModel.Transform(testingData); + var trainedModel = best.Model; + var predictions = trainedModel.Transform(testingData); var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: "Laps", scoreColumnName: "Score"); - PrintRegressionMetrics(best.TrainerName, metrics); - + TrainingHelper.PrintRegressionMetrics(best.TrainerName, metrics); // Run sample predictions @@ -96,208 +96,23 @@ static void Main(string[] args) var mvLaps = mvPred.Distance / 5412f; - // Run predictions for Bahrain - var bahrain2021 = new Race() - { - Track = Season2021.Tracks["Bahrain"], - Drivers = Season2021.Drivers, - TyreCompounds = new List() { "C2", "C3", "C4" }, - AirTemperature = 20.5f, - TrackTemperature = 28.3f - }; - //PrintAllPredictions(predictionEngine, bahrain2021); - - PrintTop10Catalunya(predictionEngine); - } - - private static void PrintAllPredictions(PredictionEngine predictionEngine, Race race) - { - foreach (var d in race.Drivers) - { - foreach (var c in race.TyreCompounds) - { - var prediction = predictionEngine.Predict(new TyreStint() - { - Track = race.Track.Name, - TrackLength = race.Track.Distance, - Team = d.Team, - Car = d.Car, - Driver = d.Name, - Compound = c, - AirTemperature = race.AirTemperature, - TrackTemperature = race.TrackTemperature, - Reason = "Pit Stop" - }); - Console.WriteLine($"| {d.Name} | {c} | {prediction.Distance / race.Track.Distance} | |"); - } - } - - } - - private static void PrintTop10Catalunya(PredictionEngine predictionEngine) - { - var Top10Catalunya = new List() + var top10Monaco = new List() { - new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C3"}, - new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C3"}, - new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C3"}, - new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Carlos Sainz", StartingCompound = "C3"}, - new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C3"}, - new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C3"}, - new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C3"}, - new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Daniel Ricciardo", StartingCompound = "C3"}, - new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C3"}, - new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Fernando Alonso", StartingCompound = "C3"}, + new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C5"}, + new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C5"}, + new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C5"}, + new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Carlos Sainz", StartingCompound = "C5"}, + new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C5"}, + new Top10Driver() {Team = "Toro Rosso / AlphaTauri", Car = "AT02", Name = "Pierre Gasly", StartingCompound = "C5"}, + new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C5"}, + new Top10Driver() {Team = "Force India / Racing Point / Aston Martin", Car = "AMR21", Name = "Sebastian Vettel", StartingCompound = "C5"}, + new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C5"}, + new Top10Driver() {Team = "Sauber / Alfa Romeo", Car = "C41", Name = "Antonio Giovinazzi", StartingCompound = "C5"}, + new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C5"}, }; - - Console.WriteLine("Catalunya"); - Console.WriteLine("=========="); - foreach (var d in Top10Catalunya) - { - var prediction = predictionEngine.Predict(new TyreStint() - { - Track = Season2021.Tracks["Catalunya"].Name, - TrackLength = Season2021.Tracks["Catalunya"].Distance, - Team = d.Team, - Car = d.Car, - Driver = d.Name, - Compound = d.StartingCompound, - AirTemperature = 25.0f, - TrackTemperature = 43.7f, - Reason = "Pit Stop" - }); - Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance } | | |"); - } - Console.WriteLine(""); + DataHelper.PrintPredictionTable(predictionEngine, "Monaco", 25.0f, 43.7f, top10Monaco); } - - - - private static void PrintTop10ImolaPortimao(PredictionEngine predictionEngine) - { - var ImolaTop10 = new List() - { - new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C3"}, - new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C4"}, - new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C3"}, - new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C4"}, - new Top10Driver() {Team = "Toro Rosso / AlphaTauri", Car = "AT02", Name = "Pierre Gasly", StartingCompound = "C4"}, - new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Daniel Ricciardo", StartingCompound = "C4"}, - new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C4"}, - new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C3"}, - new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C4"}, - new Top10Driver() {Team = "Force India / Racing Point / Aston Martin", Car = "AMR21", Name = "Lance Stroll", StartingCompound = "C4"}, - }; - - var PortimaoTop10 = new List() - { - new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C2"}, - new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C2"}, - new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C2"}, - new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C2"}, - new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Carlos Sainz", StartingCompound = "C3"}, - new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C3"}, - new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C3"}, - new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C2"}, - new Top10Driver() {Team = "Toro Rosso / AlphaTauri", Car = "AT02", Name = "Pierre Gasly", StartingCompound = "C3"}, - new Top10Driver() {Team = "Force India / Racing Point / Aston Martin", Car = "AMR21", Name = "Sebastian Vettel", StartingCompound = "C3"}, - }; - - - Console.WriteLine("Imola"); - Console.WriteLine("=========="); - foreach (var d in ImolaTop10) - { - var prediction = predictionEngine.Predict(new TyreStint() - { - Track = Season2021.Tracks["Imola"].Name, - TrackLength = Season2021.Tracks["Imola"].Distance, - Team = d.Team, - Car = d.Car, - Driver = d.Name, - Compound = d.StartingCompound, - AirTemperature = 9.3f, - TrackTemperature = 17.5f, - Reason = "Pit Stop" - }); - Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance / Season2021.Tracks["Imola"].Distance} | | |"); - } - Console.WriteLine(""); - - Console.WriteLine("Portimão"); - Console.WriteLine("=========="); - foreach (var d in PortimaoTop10) - { - var prediction = predictionEngine.Predict(new TyreStint() - { - Track = Season2021.Tracks["Portimão"].Name, - TrackLength = Season2021.Tracks["Portimão"].Distance, - Team = d.Team, - Car = d.Car, - Driver = d.Name, - Compound = d.StartingCompound, - AirTemperature = 19.8f, - TrackTemperature = 40.3f, - Reason = "Pit Stop" - }); - Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance / Season2021.Tracks["Portimão"].Distance} | | |"); - } - } - - private static void PrintTopModels(ExperimentResult experimentResult) - { - // Get top few runs ranked by R-Squared. - // R-Squared is a metric to maximize, so OrderByDescending() is correct. - // For RMSE and other regression metrics, OrderByAscending() is correct. - var topRuns = experimentResult.RunDetails - .Where(r => r.ValidationMetrics != null && !double.IsNaN(r.ValidationMetrics.RSquared)) - .OrderByDescending(r => r.ValidationMetrics.RSquared).Take(3); - - Console.WriteLine("Top models ranked by R-Squared --"); - PrintRegressionMetricsHeader(); - for (var i = 0; i < topRuns.Count(); i++) - { - var run = topRuns.ElementAt(i); - PrintIterationMetrics(i + 1, run.TrainerName, run.ValidationMetrics, run.RuntimeInSeconds); - } - } - - public static void PrintRegressionMetrics(string name, RegressionMetrics metrics) - { - Console.WriteLine($"*************************************************"); - Console.WriteLine($"* Metrics for {name} regression model "); - Console.WriteLine($"*------------------------------------------------"); - Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}"); - Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}"); - Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}"); - Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}"); - Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}"); - Console.WriteLine($"*************************************************"); - } - - internal static void PrintRegressionMetricsHeader() - { - CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", 114); - } - - internal static void PrintIterationMetrics(int iteration, string trainerName, RegressionMetrics metrics, double? runtimeInSeconds) - { - CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", 114); - } - - private static void CreateRow(string message, int width) - { - Console.WriteLine("|" + message.PadRight(width - 2) + "|"); - } - - - - } - - public class Top10Driver : Driver - { - public string StartingCompound { get; set; } } } diff --git a/src/Tyres/StaticData/Season2021.cs b/src/Tyres/StaticData/Season2021.cs index 969cfc1..8810140 100644 --- a/src/Tyres/StaticData/Season2021.cs +++ b/src/Tyres/StaticData/Season2021.cs @@ -37,6 +37,7 @@ public static class Season2021 { "Imola", new Track() {Name = "Imola", Distance = 4909f } }, { "Portimão", new Track() {Name = "Portimão", Distance = 4653f } }, { "Catalunya", new Track() {Name = "Catalunya", Distance = 4675f } }, + { "Monaco", new Track() {Name = "Monaco", Distance = 3337f } } }; } } diff --git a/src/Tyres/ViewHelpers/DataHelper.cs b/src/Tyres/ViewHelpers/DataHelper.cs new file mode 100644 index 0000000..9ac5415 --- /dev/null +++ b/src/Tyres/ViewHelpers/DataHelper.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Tyres.DataModels; +using Tyres.DomainModels; +using Tyres.StaticData; + +namespace Tyres.ViewHelpers +{ + public class DataHelper + { + public static void PrintPredictionTable(PredictionEngine predictionEngine, string trackName, float airTemperature, float trackTemperature, List drivers) + { + Console.WriteLine(Season2021.Tracks[trackName].Name); + Console.WriteLine("=========="); + foreach (var d in drivers) + { + var prediction = predictionEngine.Predict(new TyreStint() + { + Track = Season2021.Tracks[trackName].Name, + TrackLength = Season2021.Tracks[trackName].Distance, + Team = d.Team, + Car = d.Car, + Driver = d.Name, + Compound = d.StartingCompound, + AirTemperature = airTemperature, + TrackTemperature = trackTemperature, + Reason = "Pit Stop" + }); + Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance } | | |"); + } + Console.WriteLine(""); + } + } +} \ No newline at end of file diff --git a/src/Tyres/ViewHelpers/TrainingHelper.cs b/src/Tyres/ViewHelpers/TrainingHelper.cs new file mode 100644 index 0000000..0f0f10f --- /dev/null +++ b/src/Tyres/ViewHelpers/TrainingHelper.cs @@ -0,0 +1,56 @@ +using System; +using System.Linq; +using Microsoft.ML.AutoML; +using Microsoft.ML.Data; + +namespace Tyres.ViewHelpers +{ + public class TrainingHelper + { + public static void PrintTopModels(ExperimentResult experimentResult) + { + // Get top few runs ranked by R-Squared. + // R-Squared is a metric to maximize, so OrderByDescending() is correct. + // For RMSE and other regression metrics, OrderByAscending() is correct. + var topRuns = experimentResult.RunDetails + .Where(r => r.ValidationMetrics != null && !double.IsNaN(r.ValidationMetrics.RSquared)) + .OrderByDescending(r => r.ValidationMetrics.RSquared).Take(3); + + Console.WriteLine("Top models ranked by R-Squared --"); + PrintRegressionMetricsHeader(); + for (var i = 0; i < topRuns.Count(); i++) + { + var run = topRuns.ElementAt(i); + PrintIterationMetrics(i + 1, run.TrainerName, run.ValidationMetrics, run.RuntimeInSeconds); + } + } + + public static void PrintRegressionMetrics(string name, RegressionMetrics metrics) + { + Console.WriteLine($"*************************************************"); + Console.WriteLine($"* Metrics for {name} regression model "); + Console.WriteLine($"*------------------------------------------------"); + Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}"); + Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}"); + Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}"); + Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}"); + Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}"); + Console.WriteLine($"*************************************************"); + } + + internal static void PrintRegressionMetricsHeader() + { + CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", 114); + } + + internal static void PrintIterationMetrics(int iteration, string trainerName, RegressionMetrics metrics, double? runtimeInSeconds) + { + CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", 114); + } + + private static void CreateRow(string message, int width) + { + Console.WriteLine("|" + message.PadRight(width - 2) + "|"); + } + } +}