
Problems with CustomMappingEstimator after save/load because InputSchemaDefinition is not saved #3988

Open
@SkinnyMan32

Description

System information

  • ML.NET v1.2.0

Issue

I am using CustomMapping with custom column names. It works fine, but after saving and loading the model I get an exception:
System.ArgumentOutOfRangeException: "Could not find column 'Words'".
'Words' is the original property name in StemmerInput, not the column name I specified.

What am I doing wrong?

I can work around it by using only the original property names, but that is inconvenient in some cases.

Source code

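using System;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
// EnglishStemmer is not part of ML.NET; it comes from a separate stemming library (its using directive is omitted here).

// --- Test method ---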
var ml = new MLContext();
var descriptions = new[]
{
	new { Description = "Painted, Painting, Painter" }
};
var dataView = ml.Data.LoadFromEnumerable(descriptions);

var pipeline = ml.Transforms.Text.NormalizeText("Normalized", "Description")
	.Append(ml.Transforms.Text.TokenizeIntoWords("Tokens", "Normalized"))
	// (Extension method) CustomMapping with specified column names
	.Append(ml.Transforms.StemText("Stemmed", "Tokens"));

var model = pipeline.Fit(dataView);
var preview = model.Transform(dataView).Preview();  // everything is ok

// Save model
MemoryStream stream = new MemoryStream();
ml.Model.Save(model, dataView.Schema, stream);
stream.Position = 0;

// Load model in the new context
var ml2 = new MLContext();
// Register custom action
ml2.ComponentCatalog.RegisterAssembly(typeof(StemmerCustomAction).Assembly);
var loadedModel = ml2.Model.Load(stream, out var schema);

// Exception:
// System.ArgumentOutOfRangeException: "Could not find column 'Words'"
var preview2 = loadedModel.Transform(dataView).Preview();


//--- Classes ---

public class StemmerInput
{
	public string[] Words { get; set; }
}

public class StemmerOutput
{
	public string[] Stemmed { get; set; }
}

[CustomMappingFactoryAttribute("StemText")]
public class StemmerCustomAction : CustomMappingFactory<StemmerInput, StemmerOutput>
{
	public static void StemAction(StemmerInput input, StemmerOutput output)
	{
		var stemmer = new EnglishStemmer();
		output.Stemmed = new string[input.Words.Length];
		for (int i = 0; i < input.Words.Length; i++)
		{
			output.Stemmed[i] = stemmer.Stem(input.Words[i]);
		}
	}

	public override Action<StemmerInput, StemmerOutput> GetMapping() => StemAction;
}

static class StemmerTransformHelper
{
	public static CustomMappingEstimator<StemmerInput, StemmerOutput> StemText(this TransformsCatalog catalog,
		string outputColumnName, string inputColumnName = null)
	{
		var inputSchema = SchemaDefinition.Create(typeof(StemmerInput), SchemaDefinition.Direction.Read);
		var outSchema = SchemaDefinition.Create(typeof(StemmerOutput), SchemaDefinition.Direction.Write);
		// specify column names
		inputSchema[0].ColumnName = inputColumnName ?? outputColumnName;
		outSchema[0].ColumnName = outputColumnName;
		return catalog.CustomMapping(new StemmerCustomAction().GetMapping(), "StemText", inputSchema, outSchema);
	}
}
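One possible workaround, sketched here as an untested alternative to the `StemText` helper above: keep the `CustomMapping` on the default property names (`Words` and `Stemmed`), which is what the loaded model falls back to, and do the renaming with `CopyColumns` transforms instead. This assumes the column-copy transforms keep their column names through save/load (the SchemaDefinition is the only thing reported as lost). It reuses the `StemmerInput`, `StemmerOutput` and `StemmerCustomAction` classes from the snippet above and is meant to replace the original `StemmerTransformHelper`, not sit alongside it.

```csharp
using Microsoft.ML;

// Alternative StemmerTransformHelper (hypothetical workaround, untested).
// Reuses StemmerInput, StemmerOutput and StemmerCustomAction from the snippet above.
static class StemmerTransformHelper
{
    public static IEstimator<ITransformer> StemText(this TransformsCatalog catalog,
        string outputColumnName, string inputColumnName = null)
    {
        // Default column names = property names, which the mapping falls back to after Model.Load.
        const string mappingInput = nameof(StemmerInput.Words);
        const string mappingOutput = nameof(StemmerOutput.Stemmed);
        inputColumnName = inputColumnName ?? outputColumnName;

        // No SchemaDefinition arguments, so the mapping uses the same default names before and after save/load.
        IEstimator<ITransformer> chain =
            catalog.CustomMapping(new StemmerCustomAction().GetMapping(), "StemText");

        // Copy the caller's input column into the column name the mapping expects.
        if (inputColumnName != mappingInput)
            chain = catalog.CopyColumns(mappingInput, inputColumnName).Append(chain);

        // Expose the mapping's output under the requested output column name.
        if (outputColumnName != mappingOutput)
            chain = chain.Append(catalog.CopyColumns(outputColumnName, mappingOutput));

        return chain;
    }
}
```

The call site `ml.Transforms.StemText("Stemmed", "Tokens")` in the test method stays the same. The trade-off is an extra `Words` column in the transformed data, which could be removed with `ml.Transforms.DropColumns` if it gets in the way.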

Metadata

Assignees

No one assigned

    Labels

    API breaking change: The change(s) required for this issue may break the current API
    P1: Priority of the issue for triage purpose: Needs to be fixed soon.
    bug: Something isn't working
    loadsave: Bugs related to loading and saving data or models
