diff --git a/src/Tyres/DataModels/TyreStint.cs b/src/Tyres/DataModels/TyreStint.cs index a990e58..6cc48ee 100644 --- a/src/Tyres/DataModels/TyreStint.cs +++ b/src/Tyres/DataModels/TyreStint.cs @@ -1,4 +1,6 @@ -using Microsoft.ML.Data; +using System; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms; namespace Tyres.DataModels { @@ -26,19 +28,36 @@ public class TyreStint public float Laps; } - public class TransformedTyreStint : TyreStint + public class DistanceTyreStint : TyreStint { public float Distance { get; set; } } public class TyreStintPrediction { - [ColumnName("Score")] + [ColumnName("Score")] public float Distance; } - public class CustomDistanceMapping + public class LapsToDistanceInput + { + public float TrackLength { get; set; } + public float Laps { get; set; } + } + + public class LapsToDistanceOutput { public float Distance { get; set; } } + + [CustomMappingFactoryAttribute("LapsToDistance")] // contract name must equal the one passed to Transforms.CustomMapping + public class CustomMappings : CustomMappingFactory<LapsToDistanceInput, LapsToDistanceOutput> + { + public static void DistanceMapping(LapsToDistanceInput input, LapsToDistanceOutput output) => output.Distance = input.Laps * input.TrackLength; + + public override Action<LapsToDistanceInput, LapsToDistanceOutput> GetMapping() + { + return DistanceMapping; + } + } } diff --git a/src/Tyres/Program.cs b/src/Tyres/Program.cs index 25e902d..8082f59 100644 --- a/src/Tyres/Program.cs +++ b/src/Tyres/Program.cs @@ -18,6 +18,7 @@ public class Program static void Main(string[] args) { var mlContext = new MLContext(seed: 0); + mlContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly); // Load data var data = mlContext.Data.LoadFromTextFile(DatasetsLocation, ';', true); @@ -26,49 +27,33 @@ static void Main(string[] args) var filtered = mlContext.Data.FilterByCustomPredicate(data, (TyreStint row) => !(row.Reason.Equals("Pit Stop") || row.Reason.Equals("Race Finish")) ); var debug = mlContext.Data.CreateEnumerable(filtered, reuseRowObject: false).Count(); + // Transforming to add distancefeature + var
customTransformer = mlContext.Transforms.CustomMapping(CustomMappings.DistanceMapping, "LapsToDistance"); + var transformed = customTransformer.Fit(filtered).Transform(filtered); + var debug2 = mlContext.Data.CreateEnumerable(transformed, reuseRowObject: false).ToList(); + // Divide dataset into training and testing data var split = mlContext.Data.TrainTestSplit(filtered, testFraction: 0.1); var trainingData = split.TrainSet; var testingData = split.TestSet; - // Build data pipeline - var pipeline = mlContext.Transforms.CustomMapping((TyreStint input, CustomDistanceMapping output) => output.Distance = input.Laps * input.TrackLength, contractName: null) - .Append(mlContext.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: nameof(TransformedTyreStint.Distance))) - .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "TeamEncoded", inputColumnName: nameof(TyreStint.Team))) - .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "CarEncoded", inputColumnName: nameof(TyreStint.Car))) - .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "DriverEncoded", inputColumnName: nameof(TyreStint.Driver))) - .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "CompoundEncoded", inputColumnName: nameof(TyreStint.Compound))) - .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TyreStint.AirTemperature))) - .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TyreStint.TrackTemperature))) - .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "ReasonEncoded", inputColumnName: nameof(TyreStint.Reason))) - .Append(mlContext.Transforms.Concatenate("Features", - "TeamEncoded", "CarEncoded", "DriverEncoded", "CompoundEncoded", nameof(TyreStint.AirTemperature), nameof(TyreStint.TrackTemperature))); - - - /* - // Setting the training algorithm - var trainer = mlContext.Regression.Trainers.Sdca(labelColumnName: "Label", 
featureColumnName: "Features"); - var trainingPipeline = pipeline.Append(trainer); - - // Training the model - Console.WriteLine("=============== Training the model ==============="); - var trainedModel = trainingPipeline.Fit(trainingData); - - // Evaluate the model on test data - Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); - - var predictions = trainedModel.Transform(testingData); - var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: "Label", scoreColumnName: "Score"); - PrintRegressionMetrics(trainer, metrics); - - */ - var experimentTime = 600u; + // Run AutoML experiment + var experimentTime = 60u; Console.WriteLine("=============== Training the model ==============="); Console.WriteLine($"Running AutoML regression experiment for {experimentTime} seconds..."); var experimentResult = mlContext.Auto() .CreateRegressionExperiment(experimentTime) - .Execute(trainingData, progressHandler: null, labelColumnName: "Laps"); + .Execute(trainingData, testingData, + columnInformation: new ColumnInformation() + { + CategoricalColumnNames = { nameof(DistanceTyreStint.Team), nameof(DistanceTyreStint.Car), nameof(DistanceTyreStint.Driver), nameof(DistanceTyreStint.Compound), nameof(DistanceTyreStint.Reason) }, + NumericColumnNames = { nameof(DistanceTyreStint.AirTemperature), nameof(DistanceTyreStint.TrackTemperature) }, + LabelColumnName = nameof(DistanceTyreStint.Distance) + }, + preFeaturizer: customTransformer + ); + // Print top models found by AutoML Console.WriteLine(); @@ -80,23 +65,23 @@ static void Main(string[] args) var trainedModel = best.Model; var predictions = trainedModel.Transform(testingData); - var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: "Laps", scoreColumnName: "Score"); + var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: nameof(DistanceTyreStint.Distance), scoreColumnName: "Score"); TrainingHelper.PrintRegressionMetrics(best.TrainerName, metrics); 
// Run sample predictions var predictionEngine = mlContext.Model.CreatePredictionEngine(trainedModel); - var lh = new TyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Mercedes", Car = "W12", Driver = "Lewis Hamilton", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; + var lh = new DistanceTyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Mercedes", Car = "W12", Driver = "Lewis Hamilton", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; var lhPred = predictionEngine.Predict(lh); var lhLaps = lhPred.Distance / 5412f; - var mv = new TyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Red Bull", Car = "RB16B", Driver = "Max Verstappen", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; + var mv = new DistanceTyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Red Bull", Car = "RB16B", Driver = "Max Verstappen", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" }; var mvPred = predictionEngine.Predict(mv); var mvLaps = mvPred.Distance / 5412f; - + // Printing predictions for top10 grid places var top10Monaco = new List() { new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C5"},