Skip to content

Commit 380f167

Browse files
authored
Forecasting model framework for time series. (#1900)
* Forecasting interface with a unit-test. * Comments. * cleanup. * cleanup. * Internalize forecasting model and expose bare minimum needed and merge from master. * PR feedback. * PR feedback. * PR feedback. * Comments. * Samples. * Samples. * Comments. * Fix sample output. * clean up. * clean up. * docs and cleanup. * typo in doc. * PR feedback.
1 parent 33e45ba commit 380f167

File tree

9 files changed

+1831
-1211
lines changed

9 files changed

+1831
-1211
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
using Microsoft.ML.Transforms.TimeSeries;
5+
using Microsoft.ML.TimeSeries;
6+
7+
namespace Samples.Dynamic
8+
{
9+
public static class Forecasting
10+
{
11+
// This example creates a time series (list of Data with the i-th element corresponding to the i-th time slot) and then
12+
// does forecasting.
13+
// Walks through the forecasting workflow end to end: builds a small in-memory series, trains a
// forecasting model on it, forecasts five values, updates the model with new observations,
// checkpoints it to "model.zip", and prints forecasts from both the reloaded copy and the
// in-memory original. All results are written to the console.
public static void Example()
14+
{
15+
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
16+
// as well as the source of randomness.
17+
var ml = new MLContext();
18+
19+
// Generate sample series data with a recurring pattern:
// three repetitions of 0, 1, 2, 3, 4 (15 points in total).
20+
const int SeasonalitySize = 5;
21+
var data = new List<TimeSeriesData>()
22+
{
23+
new TimeSeriesData(0),
24+
new TimeSeriesData(1),
25+
new TimeSeriesData(2),
26+
new TimeSeriesData(3),
27+
new TimeSeriesData(4),
28+
29+
new TimeSeriesData(0),
30+
new TimeSeriesData(1),
31+
new TimeSeriesData(2),
32+
new TimeSeriesData(3),
33+
new TimeSeriesData(4),
34+
35+
new TimeSeriesData(0),
36+
new TimeSeriesData(1),
37+
new TimeSeriesData(2),
38+
new TimeSeriesData(3),
39+
new TimeSeriesData(4),
40+
};
41+
42+
// Convert data to IDataView.
43+
var dataView = ml.Data.LoadFromEnumerable(data);
44+
45+
// Setup arguments. The input column is named after the TimeSeriesData.Value field.
46+
var inputColumnName = nameof(TimeSeriesData.Value);
47+
48+
// Instantiate the forecasting model.
// NOTE(review): all arguments after the column name are positional; confirm their meaning
// against the AdaptiveSingularSpectrumSequenceModeler signature. SeasonalitySize / 2 is
// integer division, so it evaluates to 2 here.
49+
var model = ml.Forecasting.AdaptiveSingularSpectrumSequenceModeler(inputColumnName, data.Count, SeasonalitySize + 1, SeasonalitySize,
50+
1, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalitySize / 2, false, false);
51+
52+
// Train.
53+
model.Train(dataView);
54+
55+
// Forecast next five values.
56+
var forecast = model.Forecast(5);
57+
Console.WriteLine($"Forecasted values:");
58+
Console.WriteLine("[{0}]", string.Join(", ", forecast));
59+
// Forecasted values:
60+
// [2.452744, 2.589339, 2.729183, 2.873005, 3.028931]
61+
62+
// Update with new observations (four consecutive zeros).
63+
dataView = ml.Data.LoadFromEnumerable(new List<TimeSeriesData>() { new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0) });
64+
model.Update(dataView);
65+
66+
// Checkpoint. The path is relative, so the file location depends on the working directory.
67+
ml.Model.SaveForecastingModel(model, "model.zip");
68+
69+
// Load the checkpointed model from disk.
70+
var modelCopy = ml.Model.LoadForecastingModel<float>("model.zip");
71+
72+
// Forecast with the checkpointed model loaded from disk.
// Only the bracketed values are printed from here on (no "Forecasted values:" header).
73+
forecast = modelCopy.Forecast(5);
74+
Console.WriteLine("[{0}]", string.Join(", ", forecast));
75+
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
76+
77+
// Forecast with the original model (that was checkpointed to disk).
// The expected output matches the reloaded copy's forecast above.
78+
forecast = model.Forecast(5);
79+
Console.WriteLine("[{0}]", string.Join(", ", forecast));
80+
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
81+
82+
}
83+
84+
// A single observation of the sample time series; one instance per time slot.
class TimeSeriesData
85+
{
86+
// The observed value. The example passes this field's name as the model's input column name.
public float Value;
87+
88+
// Creates an observation holding the given value.
public TimeSeriesData(float value)
89+
{
90+
Value = value;
91+
}
92+
}
93+
}
94+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
using Microsoft.ML.Transforms.TimeSeries;
5+
using Microsoft.ML.TimeSeries;
6+
7+
namespace Samples.Dynamic
8+
{
9+
public static class ForecastingWithConfidenceInternal
10+
{
11+
// This example creates a time series (list of Data with the i-th element corresponding to the i-th time slot) and then
12+
// does forecasting.
13+
// Walks through forecasting with confidence intervals end to end: builds a small in-memory
// series, trains a forecasting model with interval computation enabled, forecasts five values
// with their lower and upper bounds, updates the model with new observations, checkpoints it to
// "model.zip", and prints forecasts from both the reloaded copy and the in-memory original.
public static void Example()
14+
{
15+
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
16+
// as well as the source of randomness.
17+
var ml = new MLContext();
18+
19+
// Generate sample series data with a recurring pattern:
// three repetitions of 0, 1, 2, 3, 4 (15 points in total).
20+
const int SeasonalitySize = 5;
21+
var data = new List<TimeSeriesData>()
22+
{
23+
new TimeSeriesData(0),
24+
new TimeSeriesData(1),
25+
new TimeSeriesData(2),
26+
new TimeSeriesData(3),
27+
new TimeSeriesData(4),
28+
29+
new TimeSeriesData(0),
30+
new TimeSeriesData(1),
31+
new TimeSeriesData(2),
32+
new TimeSeriesData(3),
33+
new TimeSeriesData(4),
34+
35+
new TimeSeriesData(0),
36+
new TimeSeriesData(1),
37+
new TimeSeriesData(2),
38+
new TimeSeriesData(3),
39+
new TimeSeriesData(4),
40+
};
41+
42+
// Convert data to IDataView.
43+
var dataView = ml.Data.LoadFromEnumerable(data);
44+
45+
// Setup arguments. The input column is named after the TimeSeriesData.Value field.
46+
var inputColumnName = nameof(TimeSeriesData.Value);
47+
48+
// Instantiate forecasting model. shouldComputeForecastIntervals is set to true so that the
// confidence-interval forecasts requested below can be produced.
// NOTE(review): the remaining arguments are positional; confirm their meaning against the
// AdaptiveSingularSpectrumSequenceModeler signature. SeasonalitySize / 2 is integer division,
// so it evaluates to 2 here.
49+
var model = ml.Forecasting.AdaptiveSingularSpectrumSequenceModeler(inputColumnName, data.Count, SeasonalitySize + 1, SeasonalitySize,
50+
1, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalitySize / 2, shouldComputeForecastIntervals: true, false);
51+
52+
// Train.
53+
model.Train(dataView);
54+
55+
// Forecast next five values with confidence intervals.
// The three arrays are filled through out parameters and reused for the later forecasts.
56+
float[] forecast;
57+
float[] confidenceIntervalLowerBounds;
58+
float[] confidenceIntervalUpperBounds;
59+
model.ForecastWithConfidenceIntervals(5, out forecast, out confidenceIntervalLowerBounds, out confidenceIntervalUpperBounds);
60+
PrintForecastValuesAndIntervals(forecast, confidenceIntervalLowerBounds, confidenceIntervalUpperBounds);
61+
// Forecasted values:
62+
// [2.452744, 2.589339, 2.729183, 2.873005, 3.028931]
63+
// Confidence intervals:
64+
// [-0.2235315 - 5.12902] [-0.08777174 - 5.266451] [0.05076938 - 5.407597] [0.1925406 - 5.553469] [0.3469928 - 5.71087]
65+
66+
// Update with new observations (four consecutive zeros).
67+
dataView = ml.Data.LoadFromEnumerable(new List<TimeSeriesData>() { new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0), new TimeSeriesData(0) });
68+
model.Update(dataView);
69+
70+
// Checkpoint. The path is relative, so the file location depends on the working directory.
71+
ml.Model.SaveForecastingModel(model, "model.zip");
72+
73+
// Load the checkpointed model from disk.
74+
var modelCopy = ml.Model.LoadForecastingModel<float>("model.zip");
75+
76+
// Forecast with the checkpointed model loaded from disk.
77+
modelCopy.ForecastWithConfidenceIntervals(5, out forecast, out confidenceIntervalLowerBounds, out confidenceIntervalUpperBounds);
78+
PrintForecastValuesAndIntervals(forecast, confidenceIntervalLowerBounds, confidenceIntervalUpperBounds);
79+
// Forecasted values:
80+
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
81+
// Confidence intervals:
82+
// [-1.808158 - 3.544394] [-1.8586 - 3.495622] [-1.871486 - 3.485341] [-1.836414 - 3.524514] [-1.736431 - 3.627447]
83+
84+
// Forecast with the original model (that was checkpointed to disk).
// The expected output matches the reloaded copy's forecast above.
85+
model.ForecastWithConfidenceIntervals(5, out forecast, out confidenceIntervalLowerBounds, out confidenceIntervalUpperBounds);
86+
PrintForecastValuesAndIntervals(forecast, confidenceIntervalLowerBounds, confidenceIntervalUpperBounds);
87+
// Forecasted values:
88+
// [0.8681176, 0.8185108, 0.8069275, 0.84405, 0.9455081]
89+
// Confidence intervals:
90+
// [-1.808158 - 3.544394] [-1.8586 - 3.495622] [-1.871486 - 3.485341] [-1.836414 - 3.524514] [-1.736431 - 3.627447]
91+
}
92+
93+
// Writes the forecast and its confidence intervals to the console.
// The forecast is printed as one bracketed, comma-separated list; the intervals follow on a
// single line as "[lower - upper] " pairs, terminated by a newline.
// Both bound arrays are indexed by the positions of forecast, so each must contain at least
// forecast.Length elements.
static void PrintForecastValuesAndIntervals(float[] forecast, float[] confidenceIntervalLowerBounds, float[] confidenceIntervalUpperBounds)
94+
{
95+
Console.WriteLine($"Forecasted values:");
96+
Console.WriteLine("[{0}]", string.Join(", ", forecast));
97+
Console.WriteLine($"Confidence intervals:");
98+
// Console.Write (not WriteLine) keeps all interval pairs on the same output line.
for (int index = 0; index < forecast.Length; index++)
99+
Console.Write($"[{confidenceIntervalLowerBounds[index]} - {confidenceIntervalUpperBounds[index]}] ");
100+
// Terminate the line of intervals.
Console.WriteLine();
101+
}
102+
103+
// A single observation of the sample time series; one instance per time slot.
class TimeSeriesData
104+
{
105+
// The observed value. The example passes this field's name as the model's input column name.
public float Value;
106+
107+
// Creates an observation holding the given value.
public TimeSeriesData(float value)
108+
{
109+
Value = value;
110+
}
111+
}
112+
}
113+
}

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ public sealed class MLContext : IHostEnvironment
4646
/// </summary>
4747
public AnomalyDetectionCatalog AnomalyDetection { get; }
4848

49+
/// <summary>
50+
/// Trainers and tasks specific to forecasting problems.
51+
/// </summary>
52+
public ForecastingCatalog Forecasting { get; }
53+
4954
/// <summary>
5055
/// Data processing operations.
5156
/// </summary>
@@ -89,6 +94,7 @@ public MLContext(int? seed = null)
8994
Clustering = new ClusteringCatalog(_env);
9095
Ranking = new RankingCatalog(_env);
9196
AnomalyDetection = new AnomalyDetectionCatalog(_env);
97+
Forecasting = new ForecastingCatalog(_env);
9298
Transforms = new TransformsCatalog(_env);
9399
Model = new ModelOperationsCatalog(_env);
94100
Data = new DataOperationsCatalog(_env);

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,4 +704,31 @@ public AnomalyDetectionMetrics Evaluate(IDataView data, string labelColumnName =
704704
return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName);
705705
}
706706
}
707+
708+
/// <summary>
709+
/// Class used by <see cref="MLContext"/> to create instances of forecasting components.
710+
/// </summary>
711+
public sealed class ForecastingCatalog : TrainCatalogBase
712+
{
713+
/// <summary>
714+
/// The list of trainers for performing forecasting.
715+
/// </summary>
716+
public Forecasters Trainers { get; }
717+
718+
/// <summary>
/// Creates the catalog, registering it with the base class under the name
/// <c>nameof(ForecastingCatalog)</c>, and initializes <see cref="Trainers"/>.
/// The constructor is internal, so it is not callable from outside this assembly.
/// </summary>
/// <param name="env">The host environment forwarded to the base catalog.</param>
internal ForecastingCatalog(IHostEnvironment env) : base(env, nameof(ForecastingCatalog))
719+
{
720+
Trainers = new Forecasters(this);
721+
}
722+
723+
/// <summary>
724+
/// Class used by <see cref="MLContext"/> to create instances of forecasting trainers.
725+
/// </summary>
// NOTE(review): no trainer members are declared in this class here; presumably they are
// supplied elsewhere (e.g. as extension methods) - confirm.
726+
public sealed class Forecasters : CatalogInstantiatorBase
727+
{
728+
/// <summary>
/// Creates the instantiator for the given forecasting catalog.
/// </summary>
/// <param name="catalog">The owning catalog, forwarded to the base class.</param>
internal Forecasters(ForecastingCatalog catalog)
729+
: base(catalog)
730+
{
731+
}
732+
}
733+
}
707734
}

0 commit comments

Comments
 (0)