diff --git a/scripts/perftest/ohe_checks/README.md b/scripts/perftest/ohe_checks/README.md
new file mode 100644
index 00000000000..bdf0df4eaa4
--- /dev/null
+++ b/scripts/perftest/ohe_checks/README.md
@@ -0,0 +1,121 @@
+# Checking One-Hot-Encodedness before Compression Tests
+
+To run all tests for the One-Hot-Encoding checks:
+ * install SystemDS
+ * make sure the SYSTEMDS_ROOT, JAVA_HOME, and HADOOP_HOME environment variables are set correctly
+ * set LOG4JPROP to `$SYSTEMDS_ROOT/scripts/perftest/ohe_checks/log4j-compression.properties`
+ * run `experiments.sh`
+
+Alternatively, to run the experiment.dml script directly with OHE checks enabled, use this command:
+
+`$SYSTEMDS_ROOT/bin/systemds $SYSTEMDS_ROOT/target/SystemDS.jar experiment.dml --config ohe.xml`
+
+Note: you can use `-nvargs` to set the variables `rows`, `cols`, `dummy`, `distinct`, and `repeats` (how many times to generate a random matrix, transform-encode it, and compress it).
+
+`dummy` is the array of column indexes to one-hot encode; for example, `dummy="[1]"` one-hot encodes the first column.
+
+To collect the metrics from the logs for easier comparison, run `parse_logs.py`; an Excel file called `combined_metrics.xlsx` will be created in this directory.
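For reference, the regular expressions used by `parse_logs.py` can be exercised on a synthetic log excerpt (the numbers below are illustrative placeholders, not real measurements):

```python
import re

# The same patterns parse_logs.py applies to the SystemDS compression DEBUG output
patterns = {
    "num_col_groups": re.compile(r"--num col groups:\s+(\d+)"),
    "compressed_size": re.compile(r"--compressed size:\s+(\d+)"),
    "compression_ratio": re.compile(r"--compression ratio:\s+([\d.]+)"),
    "execution_time": re.compile(r"Total execution time:\s+([\d.]+) sec."),
}

# Synthetic log excerpt (illustrative values only)
log = """
--num col groups:               3
--compressed size:         123456
--compression ratio:        4.271
Total execution time:       1.234 sec.
"""

# Extract the first match of each pattern, or None if absent (as in parse_logs.py)
metrics = {}
for key, pattern in patterns.items():
    match = pattern.search(log)
    metrics[key] = match.group(1) if match else None

print(metrics)
```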
+
+---
+
+# Documentation of Codebase Changes for Implementing the OHE Checks
+
+## Flag to enable/disable OHE checks (Disabled by default)
+- Added ``COMPRESSED_ONEHOTDETECT = "sysds.compressed.onehotdetect"`` to ``DMLConfig`` and adjusted the relevant methods
+- Added the attribute ``public final boolean oneHotDetect`` to ``CompressionSettings`` and adjusted the relevant methods
+- Adjusted ``CompressionSettingsBuilder`` to check if ``COMPRESSED_ONEHOTDETECT`` has been set to true to enable the checks
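With these changes in place, the check can be switched on from a SystemDS configuration file. The config key is taken verbatim from the diff; the `<root>` element is assumed to be the usual SystemDS config root, mirroring the `ohe.xml` used by the experiments:

```xml
<root>
    <sysds.compressed.onehotdetect>true</sysds.compressed.onehotdetect>
</root>
```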
+
+## Changes in `CoCoderFactory`
+
+### 1. Introduction of OHE Detection
+
+**Condition Addition:**
+- Added a condition to check for `cs.oneHotDetect` along with the existing condition `!containsEmptyConstOrIncompressable` in the `findCoCodesByPartitioning` method. This ensures that the process considers OHE detection only if it is enabled in the compression settings.
+- Original code only checked for `containsEmptyConstOrIncompressable` and proceeded to cocode all columns if false. The updated code includes an additional check for `cs.oneHotDetect`.
+
+### 2. New Data Structures for OHE Handling
+
+**New Lists:** Introduced two new lists to manage the OHE detection process:
+- `currentCandidates`: To store the current candidate columns that might form an OHE group.
+- `oheGroups`: To store lists of columns that have been validated as OHE groups.
+
+### 3. Filtering Logic Enhancements
+
+**Column Filtering:** Enhanced the loop that iterates over columns to identify OHE candidates:
+- Columns that are empty, constant, or incompressible are filtered into respective lists.
+- For other columns, they are added to `currentCandidates` if they are deemed candidates (via `isCandidate` function).
+
+### 4. Addition of `isHotEncoded` Function
+
+**Function Creation:** Created a new `isHotEncoded` function to evaluate if the accumulated columns form a valid OHE group.
+- **Parameters:** Takes a list of column groups (`colGroups`), a boolean flag (`isSample`), an array of non-zero counts (`nnzCols`), and the number of rows (`numRows`).
+- **Return Type:** Returns a `String` indicating the status of the current candidates:
+ - `"POTENTIAL_OHE"`: When the current candidates could still form an OHE group.
+ - `"NOT_OHE"`: When the current candidates cannot form an OHE group.
+ - `"VALID_OHE"`: When the current candidates form a valid OHE group.
+- **Logic:** The function calculates the total number of distinct values and offsets, and checks if they meet the criteria for forming an OHE group.
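The tri-state logic above can be sketched as follows. This is a simplified Python rendition, not the Java implementation: each column group is reduced to a `(numVals, numOffs)` pair, the sampling/`nnzCols` path is omitted, and integer division mirrors the Java `totalNumVals / 2` check:

```python
def is_hot_encoded(num_vals_per_group, num_offs_per_group, num_rows):
    """Classify accumulated candidate column groups as NOT_OHE, VALID_OHE,
    or POTENTIAL_OHE (sketch of the isHotEncoded check)."""
    if not num_vals_per_group:
        return "NOT_OHE"
    num_cols = len(num_vals_per_group)
    total_vals = 0
    total_offs = 0
    for vals, offs in zip(num_vals_per_group, num_offs_per_group):
        total_vals += vals
        # Too many distinct values for an OHE group of this many columns
        if total_vals // 2 > num_cols:
            return "NOT_OHE"
        total_offs += offs
        # More non-zeros than rows: some row must contain more than one 1
        if total_offs > num_rows:
            return "NOT_OHE"
    # Valid when non-zeros cover every row exactly once
    if total_vals // 2 == num_cols and total_offs == num_rows:
        return "VALID_OHE"
    # Otherwise, adding more columns could still complete the group
    return "POTENTIAL_OHE"
```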
+
+### 5. Enhanced Group Handling
+
+**Candidate Processing:** Within the loop, after adding a column to `currentCandidates`:
+- Calls `isHotEncoded` to check the status of the candidates.
+- If `isHotEncoded` returns `"NOT_OHE"`, moves the candidates to regular groups and clears the candidates list.
+- If `isHotEncoded` returns `"VALID_OHE"`, moves the candidates to `oheGroups` and clears the candidates list.
+- If `isHotEncoded` returns `"POTENTIAL_OHE"`, continues accumulating candidates.
+
+### 6. Final Candidate Check
+
+**Post-loop Check:** After the loop, checks any remaining `currentCandidates`:
+- If they form a valid OHE group, adds them to `oheGroups`.
+- Otherwise, adds them to regular groups.
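The accumulate-and-flush behavior of sections 5 and 6 can be sketched as a small state machine. This is a simplified illustration, not the Java code: groups are `(num_vals, num_offs)` pairs, candidacy is `num_vals == 2` as in `isCandidate`, and the distinct-value bound is implied because every candidate contributes exactly two values:

```python
def partition_ohe(groups, num_rows):
    """Split groups into regular groups and validated OHE groups by
    accumulating binary candidates until their non-zeros cover every row."""
    regular, ohe_groups, candidates = [], [], []
    offs_sum = 0
    for g in groups:
        num_vals, num_offs = g
        if num_vals != 2:                  # not a candidate: flush pending candidates
            regular.extend(candidates)
            candidates, offs_sum = [], 0
            regular.append(g)
            continue
        candidates.append(g)
        offs_sum += num_offs
        if offs_sum > num_rows:            # NOT_OHE: demote candidates to regular groups
            regular.extend(candidates)
            candidates, offs_sum = [], 0
        elif offs_sum == num_rows:         # VALID_OHE: emit a combined OHE group
            ohe_groups.append(candidates)
            candidates, offs_sum = [], 0
        # else POTENTIAL_OHE: keep accumulating
    regular.extend(candidates)             # post-loop check (section 6, simplified)
    return regular, ohe_groups
```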
+
+### 7. Overwrite and CoCode Groups
+
+**Overwrite Groups:** Updates `colInfos.compressionInfo` with the processed `groups`.
+**OHE Group Integration:** Combines indexes for validated OHE groups and adds them to the final `groups`.
+
+## One Hot Encoded Columns Compression in `ColGroupFactory`
+
+### Description
+
+The `compressOHE` function is designed to compress columns that are one-hot encoded (OHE). It validates and processes the input data to ensure it meets the criteria for one-hot encoding, and if so, it compresses the data accordingly. If the data does not meet the OHE criteria, it falls back to a direct compression method (`directCompressDDC`).
+
+### Implementation Details
+
+1. **Validation of `numVals`**:
+ - Ensures the number of distinct values (`numVals`) in the column group is greater than 0.
+ - Throws a `DMLCompressionException` if `numVals` is less than or equal to 0.
+
+2. **Handling Transposed Matrix**:
+ - If the matrix is transposed (`cs.transposed` is `true`):
+ - Creates a `MapToFactory` data structure with an additional unique value.
+ - Iterates through the sparse block of the matrix, checking for non-one values or multiple ones in the same row.
+ - If a column index in the sparse block is empty, or if non-one values or multiple ones are found, it falls back to `directCompressDDC`.
+
+3. **Handling Non-Transposed Matrix**:
+ - If the matrix is not transposed (`cs.transposed` is `false`):
+ - Creates a `MapToFactory` data structure.
+ - Iterates through each row of the matrix:
+ - Checks for the presence of exactly one '1' in the columns specified by `colIndexes`.
+ - If multiple ones are found in the same row, or if no '1' is found in a sample row, it falls back to `directCompressDDC`.
+
+4. **Return Value**:
+ - If the data meets the OHE criteria, returns a `ColGroupDDC` created with the column indexes, an `IdentityDictionary`, and the data.
+ - If the data does not meet the OHE criteria, returns the result of `directCompressDDC`.
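The row-wise validation of the non-transposed case can be sketched in Python. This is an illustrative simplification (dense lists instead of SystemDS matrix blocks, and `None` standing in for the `directCompressDDC` fallback), not the actual `compressOHE` code:

```python
def validate_ohe_rows(matrix, col_indexes):
    """Check that every row has exactly one 1 (and no other non-zero) in the
    selected columns. Returns the per-row hot-column index on success, or
    None to signal that the caller should fall back to DDC compression."""
    mapping = []
    for row in matrix:
        hot = -1
        for j, c in enumerate(col_indexes):
            v = row[c]
            if v == 0:
                continue
            if v != 1 or hot != -1:   # a non-one value, or a second 1 in this row
                return None
            hot = j
        if hot == -1:                 # no 1 found in this row
            return None
        mapping.append(hot)
    return mapping
```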
diff --git a/scripts/perftest/ohe_checks/experiment.dml b/scripts/perftest/ohe_checks/experiment.dml
new file mode 100644
index 00000000000..c89f7dcaef6
--- /dev/null
+++ b/scripts/perftest/ohe_checks/experiment.dml
@@ -0,0 +1,47 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+## This script generates a random matrix, one-hot encodes the selected columns via transformencode, and then compresses the result
+
+# Set default values
+default_rows = 1000
+default_cols = 10
+default_dummy = "[1]"
+default_repeats = 1
+default_num_distinct = 10
+
+#nvargs
+rows = ifdef($rows, default_rows)
+cols = ifdef($cols, default_cols)
+dummy = ifdef($dummy, default_dummy)
+repeats = ifdef($repeats, default_repeats)
+num_distinct = ifdef($distinct, default_num_distinct)
+
+# Generate random matrix and apply transformations
+x = rand(rows=rows, cols=cols, min=0, max=num_distinct)
+x = floor(x)
+Fall = as.frame(x)
+jspec = "{ids: true, dummycode: " + dummy + "}";
+for(i in 1:repeats){
+ [T,M] = transformencode(target=Fall, spec=jspec)
+ xc = compress(T)
+}
+print(toString(xc))
diff --git a/scripts/perftest/ohe_checks/experiments.sh b/scripts/perftest/ohe_checks/experiments.sh
new file mode 100644
index 00000000000..2e6e2bd2786
--- /dev/null
+++ b/scripts/perftest/ohe_checks/experiments.sh
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+mkdir -p BaselineLogs
+mkdir -p OHELogs
+run_base() {
+ $SYSTEMDS_ROOT/bin/systemds $SYSTEMDS_ROOT/target/SystemDS.jar experiment.dml \
+ --seed 42 --debug -nvargs rows=$1 cols=$2 dummy="$3" distinct=$4 > BaselineLogs/${5}_${1}_rows_${2}_cols_${3}_encoded_base.txt 2>&1
+}
+
+run_ohe() {
+ $SYSTEMDS_ROOT/bin/systemds $SYSTEMDS_ROOT/target/SystemDS.jar experiment.dml \
+    --seed 42 --debug --config ohe.xml -nvargs rows=$1 cols=$2 dummy="$3" distinct=$4 > OHELogs/${5}_${1}_rows_${2}_cols_${3}_encoded_ohe.txt 2>&1
+}
+
+# Run same experiments but checking One-Hot-Encoded columns first
+run_ohe 1000 1 "[1]" 10 1
+run_ohe 1000 5 "[2]" 10 2
+run_ohe 1000 5 "[1,2]" 10 3
+run_ohe 1000 5 "[1,2,3]" 10 4
+run_ohe 1000 5 "[1,2,3,4,5]" 10 5
+run_ohe 1000 10 "[1,3,5]" 10 6
+run_ohe 1000 10 "[1,2,5,6]" 10 7
+run_ohe 100000 1 "[1]" 100 8
+run_ohe 100000 5 "[1,2]" 100 9
+run_ohe 100000 5 "[1,2,3]" 100 10
+run_ohe 100000 100 "[1,3,50,60,70,80]" 100 11
+run_ohe 100000 100 "[1,2,24,25,50,51]" 100 12
+
+# Run baseline experiments
+run_base 1000 1 "[1]" 10 1
+run_base 1000 5 "[2]" 10 2
+run_base 1000 5 "[1,2]" 10 3
+run_base 1000 5 "[1,2,3]" 10 4
+run_base 1000 5 "[1,2,3,4,5]" 10 5
+run_base 1000 10 "[1,3,5]" 10 6
+run_base 1000 10 "[1,2,5,6]" 10 7
+run_base 100000 1 "[1]" 100 8
+run_base 100000 5 "[1,2]" 100 9
+run_base 100000 5 "[1,2,3]" 100 10
+run_base 100000 100 "[1,3,50,60,70,80]" 100 11
+run_base 100000 100 "[1,2,24,25,50,51]" 100 12
diff --git a/scripts/perftest/ohe_checks/log4j-compression.properties b/scripts/perftest/ohe_checks/log4j-compression.properties
new file mode 100644
index 00000000000..5f7c1cd70b5
--- /dev/null
+++ b/scripts/perftest/ohe_checks/log4j-compression.properties
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+log4j.rootLogger=ERROR, console
+
+log4j.logger.org.apache.sysds=INFO
+log4j.logger.org.apache.sysds.runtime.compress=DEBUG
+log4j.logger.org.apache.spark=ERROR
+log4j.logger.org.apache.spark.SparkContext=OFF
+log4j.logger.org.apache.hadoop=ERROR
+
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{2}: %m%n
diff --git a/scripts/perftest/ohe_checks/ohe.xml b/scripts/perftest/ohe_checks/ohe.xml
new file mode 100644
index 00000000000..35c3e481a07
--- /dev/null
+++ b/scripts/perftest/ohe_checks/ohe.xml
@@ -0,0 +1,22 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+<root>
+   <sysds.compressed.onehotdetect>true</sysds.compressed.onehotdetect>
+</root>
diff --git a/scripts/perftest/ohe_checks/parse_logs.py b/scripts/perftest/ohe_checks/parse_logs.py
new file mode 100644
index 00000000000..edd27a6e7e4
--- /dev/null
+++ b/scripts/perftest/ohe_checks/parse_logs.py
@@ -0,0 +1,90 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+import os
+import re
+import pandas as pd
+
+# Patterns to search for
+patterns = {
+ "num_col_groups": re.compile(r"--num col groups:\s+(\d+)"),
+ "compressed_size": re.compile(r"--compressed size:\s+(\d+)"),
+ "compression_ratio": re.compile(r"--compression ratio:\s+([\d.]+)"),
+ "execution_time": re.compile(r"Total execution time:\s+([\d.]+) sec.")
+}
+
+# Function to extract metrics from the text files
+def extract_metrics(file_path):
+ with open(file_path, 'r') as file:
+ content = file.read()
+ metrics = {}
+ for key, pattern in patterns.items():
+ match = pattern.search(content)
+ if match:
+ metrics[key] = match.group(1)
+ else:
+ metrics[key] = None
+ return metrics
+
+# Directories for baseline and OHE
+baseline_dir = "BaselineLogs"
+ohe_dir = "OHELogs"
+
+# Data storage
+data_combined = []
+
+# Process baseline and corresponding OHE files
+for file_name in os.listdir(baseline_dir):
+ if file_name.endswith("_encoded_base.txt"):
+ experiment_name = file_name[:-4] # Remove the .txt extension
+ file_path_baseline = os.path.join(baseline_dir, file_name)
+ metrics_baseline = extract_metrics(file_path_baseline)
+
+ file_name_ohe = file_name.replace('_base.txt', '_ohe.txt').replace('Baseline', 'OHE')
+ file_path_ohe = os.path.join(ohe_dir, file_name_ohe)
+ if os.path.exists(file_path_ohe):
+ metrics_ohe = extract_metrics(file_path_ohe)
+ else:
+ metrics_ohe = {key: None for key in patterns.keys()}
+
+ combined_metrics = {
+ 'experiment': experiment_name,
+ 'baseline_num_col_groups': metrics_baseline.get('num_col_groups'),
+ 'baseline_compressed_size': metrics_baseline.get('compressed_size'),
+ 'baseline_compression_ratio': metrics_baseline.get('compression_ratio'),
+ 'baseline_execution_time': metrics_baseline.get('execution_time'),
+ 'ohe_num_col_groups': metrics_ohe.get('num_col_groups'),
+ 'ohe_compressed_size': metrics_ohe.get('compressed_size'),
+ 'ohe_compression_ratio': metrics_ohe.get('compression_ratio'),
+ 'ohe_execution_time': metrics_ohe.get('execution_time')
+ }
+
+ data_combined.append(combined_metrics)
+
+# Create DataFrame
+df_combined = pd.DataFrame(data_combined)
+
+# Write to Excel
+output_file = 'combined_metrics.xlsx'
+df_combined.to_excel(output_file, index=False)
+
+print(f'Excel file "{output_file}" has been created successfully.')
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index dd4d3b2457f..11b1e060a3c 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -88,6 +88,7 @@ public class DMLConfig
public static final String COMPRESSED_COST_MODEL= "sysds.compressed.costmodel";
public static final String COMPRESSED_TRANSPOSE = "sysds.compressed.transpose";
public static final String COMPRESSED_TRANSFORMENCODE = "sysds.compressed.transformencode";
+ public static final String COMPRESSED_ONEHOTDETECT = "sysds.compressed.onehotdetect";
public static final String NATIVE_BLAS = "sysds.native.blas";
public static final String NATIVE_BLAS_DIR = "sysds.native.blas.directory";
public static final String DAG_LINEARIZATION = "sysds.compile.linearization";
@@ -173,6 +174,7 @@ public class DMLConfig
_defaultVals.put(COMPRESSED_COST_MODEL, "AUTO");
_defaultVals.put(COMPRESSED_TRANSPOSE, "auto");
_defaultVals.put(COMPRESSED_TRANSFORMENCODE, "false");
+ _defaultVals.put(COMPRESSED_ONEHOTDETECT, "false");
_defaultVals.put(DAG_LINEARIZATION, DagLinearizer.DEPTH_FIRST.name());
_defaultVals.put(CODEGEN, "false" );
_defaultVals.put(CODEGEN_API, GeneratorAPI.JAVA.name() );
@@ -458,7 +460,7 @@ public String getConfigInfo() {
CP_PARALLEL_OPS, CP_PARALLEL_IO, PARALLEL_ENCODE, NATIVE_BLAS, NATIVE_BLAS_DIR,
COMPRESSED_LINALG, COMPRESSED_LOSSY, COMPRESSED_VALID_COMPRESSIONS, COMPRESSED_OVERLAPPING,
COMPRESSED_SAMPLING_RATIO, COMPRESSED_SOFT_REFERENCE_COUNT,
- COMPRESSED_COCODE, COMPRESSED_TRANSPOSE, COMPRESSED_TRANSFORMENCODE, DAG_LINEARIZATION,
+ COMPRESSED_COCODE, COMPRESSED_TRANSPOSE, COMPRESSED_TRANSFORMENCODE, COMPRESSED_ONEHOTDETECT, DAG_LINEARIZATION,
CODEGEN, CODEGEN_API, CODEGEN_COMPILER, CODEGEN_OPTIMIZER, CODEGEN_PLANCACHE, CODEGEN_LITERALS,
STATS_MAX_WRAP_LEN, LINEAGECACHESPILL, COMPILERASSISTED_RW, BUFFERPOOL_LIMIT, MEMORY_MANAGER,
PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, GPU_RULE_BASED_PLACEMENT,
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index b31ef4afddf..9aa934dfe7b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -285,6 +285,7 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv
_stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0);
_stats.sparseSize = MatrixBlock.estimateSizeSparseInMemory(mb.getNumRows(), mb.getNumColumns(), mb.getSparsity());
+ _stats.sparsity = mb.getSparsity();
_stats.originalSize = mb.getInMemorySize();
_stats.originalCost = costEstimator.getCost(mb);
@@ -522,6 +523,7 @@ private void logPhase() {
LOG.debug("--col groups sizes " + _stats.getGroupsSizesString());
LOG.debug("--input was compressed " + (mb instanceof CompressedMatrixBlock));
LOG.debug(String.format("--dense size: %16d", _stats.denseSize));
+ LOG.debug(String.format("--sparsity: %4.3f", _stats.sparsity));
LOG.debug(String.format("--sparse size: %16d", _stats.sparseSize));
LOG.debug(String.format("--original size: %16d", _stats.originalSize));
LOG.debug(String.format("--compressed size: %16d", _stats.compressedSize));
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
index 062ccfc1201..4fedbabe777 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
@@ -128,11 +128,14 @@ public class CompressionSettings {
/** The sorting type used in sorting/joining offsets to create SDC groups */
public final SORT_TYPE sdcSortType;
+ /** Flag to detect one hot encodedness */
+ public final boolean oneHotDetect;
+
protected CompressionSettings(double samplingRatio, double samplePower, boolean allowSharedDictionary,
		String transposeInput, int seed, boolean lossy, EnumSet<CompressionType> validCompressions,
boolean sortValuesByLength, PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage,
int minimumSampleSize, int maxSampleSize, EstimationType estimationType, CostType costComputationType,
- double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType) {
+ double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType, boolean oneHotDetect) {
this.samplingRatio = samplingRatio;
this.samplePower = samplePower;
this.allowSharedDictionary = allowSharedDictionary;
@@ -151,6 +154,7 @@ protected CompressionSettings(double samplingRatio, double samplePower, boolean
this.minimumCompressionRatio = minimumCompressionRatio;
this.isInSparkInstruction = isInSparkInstruction;
this.sdcSortType = sdcSortType;
+ this.oneHotDetect = oneHotDetect;
if(LOG.isDebugEnabled())
LOG.debug(this.toString());
}
@@ -168,6 +172,7 @@ public String toString() {
sb.append("\t Partitioner: " + columnPartitioner);
sb.append("\t Lossy: " + lossy);
sb.append("\t Cost Computation Type: " + costComputationType);
+ sb.append("\t One Hot Encoding Check Flag: " + oneHotDetect);
if(samplingRatio < 1.0)
sb.append("\t Estimation Type: " + estimationType);
return sb.toString();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
index ec5512266e8..54f6ee07867 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
@@ -55,6 +55,7 @@ public class CompressionSettingsBuilder {
private double minimumCompressionRatio = 1.0;
private boolean isInSparkInstruction = false;
private SORT_TYPE sdcSortType = SORT_TYPE.MATERIALIZE;
+ private boolean oneHotDetect = false;
public CompressionSettingsBuilder() {
@@ -69,6 +70,7 @@ public CompressionSettingsBuilder() {
costType = CostType.valueOf(conf.getTextValue(DMLConfig.COMPRESSED_COST_MODEL));
transposeInput = conf.getTextValue(DMLConfig.COMPRESSED_TRANSPOSE);
seed = DMLScript.SEED;
+ oneHotDetect = conf.getBooleanValue(DMLConfig.COMPRESSED_ONEHOTDETECT);
}
@@ -90,6 +92,7 @@ public CompressionSettingsBuilder copySettings(CompressionSettings that) {
this.maxColGroupCoCode = that.maxColGroupCoCode;
this.coCodePercentage = that.coCodePercentage;
this.minimumSampleSize = that.minimumSampleSize;
+ this.oneHotDetect = that.oneHotDetect;
return this;
}
@@ -170,6 +173,17 @@ public CompressionSettingsBuilder setSeed(int seed) {
return this;
}
+ /**
+ * Set the flag for detecting one hot encodedness for the compression operation.
+ *
+ * @param enabled The flag to enable or disable OHE checks.
+ * @return The CompressionSettingsBuilder
+ */
+ public CompressionSettingsBuilder setOneHotDetect(boolean enabled) {
+ this.oneHotDetect = enabled;
+ return this;
+ }
+
/**
* Set the valid compression strategies used for the compression.
*
@@ -334,6 +348,6 @@ public CompressionSettings create() {
return new CompressionSettings(samplingRatio, samplePower, allowSharedDictionary, transposeInput, seed, lossy,
validCompressions, sortValuesByLength, columnPartitioner, maxColGroupCoCode, coCodePercentage,
minimumSampleSize, maxSampleSize, estimationType, costType, minimumCompressionRatio, isInSparkInstruction,
- sdcSortType);
+ sdcSortType, oneHotDetect);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java
index d54eb2c3525..9708e0b19f7 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java
@@ -46,6 +46,8 @@ public class CompressionStatistics {
/** Compressed size */
public long compressedSize;
+ /** Sparsity of input matrix */
+ public double sparsity;
/** Cost calculated by the cost estimator on input */
public double originalCost = Double.NaN;
/** Summed cost estimated from individual columns */
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
index 6c560fb9792..4ed7e0fbe25 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
@@ -25,6 +25,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
@@ -62,13 +63,14 @@ public static CompressedSizeInfo findCoCodesByPartitioning(AComEst est, Compress
// Use column group partitioner to create partitions of columns
AColumnCoCoder co = createColumnGroupPartitioner(cs.columnPartitioner, est, costEstimator, cs);
- // Find out if any of the groups are empty.
- final boolean containsEmptyConstOrIncompressable = containsEmptyConstOrIncompressable(colInfos);
+ boolean containsEmptyConstOrIncompressable = containsEmptyConstOrIncompressable(colInfos);
- // if there are no empty or const columns then try cocode algorithms for all columns
- if(!containsEmptyConstOrIncompressable)
+		// If there are no empty, constant, or incompressible groups and OHE detection is disabled, cocode all columns
+ if(!containsEmptyConstOrIncompressable && !cs.oneHotDetect) {
return co.coCodeColumns(colInfos, k);
+ }
else {
+
// filtered empty groups
+			final List<IColIndex> emptyCols = new ArrayList<>();
// filtered const groups
@@ -81,6 +83,12 @@ public static CompressedSizeInfo findCoCodesByPartitioning(AComEst est, Compress
final int nRow = colInfos.compressionInfo.get(0).getNumRows();
// filter groups
+			List<CompressedSizeInfoColGroup> currentCandidates = new ArrayList<>();
+			List<List<CompressedSizeInfoColGroup>> oheGroups = new ArrayList<>();
+ boolean isSample = est.getClass().getSimpleName().equals("ComEstSample");
+ if(est.getNnzCols() == null)
+ LOG.debug("NNZ is null");
+ int[] nnzCols = est.getNnzCols();
for(int i = 0; i < colInfos.compressionInfo.size(); i++) {
CompressedSizeInfoColGroup g = colInfos.compressionInfo.get(i);
if(g.isEmpty())
@@ -89,14 +97,54 @@ else if(g.isConst())
constCols.add(g.getColumns());
else if(g.isIncompressable())
incompressable.add(g.getColumns());
- else
+ else if(isCandidate(g)) {
+ currentCandidates.add(g);
+ String oheStatus = isHotEncoded(currentCandidates, isSample, nnzCols, nRow);
+ if(oheStatus.equals("NOT_OHE")) {
+ groups.addAll(currentCandidates);
+ currentCandidates.clear();
+ }
+ else if(oheStatus.equals("VALID_OHE")) {
+ LOG.debug("FOUND OHE");
+ oheGroups.add(new ArrayList<>(currentCandidates));
+ currentCandidates.clear();
+ }
+ }
+ else {
groups.add(g);
+ if(!currentCandidates.isEmpty()) {
+ for(CompressedSizeInfoColGroup gg : currentCandidates)
+ groups.add(gg);
+ currentCandidates.clear();
+ }
+ }
+ }
+
+ // If currentCandidates is not empty, add it to groups
+ if(!currentCandidates.isEmpty()) {
+ for(CompressedSizeInfoColGroup gg : currentCandidates) {
+ groups.add(gg);
+ }
+ currentCandidates.clear();
}
// overwrite groups.
colInfos.compressionInfo = groups;
+			for(List<CompressedSizeInfoColGroup> oheGroup : oheGroups) {
+				final List<IColIndex> oheIndexes = new ArrayList<>();
+ for(CompressedSizeInfoColGroup g : oheGroup) {
+ oheIndexes.add(g.getColumns());
+ }
+ final IColIndex idx = ColIndexFactory.combineIndexes(oheIndexes);
+ groups.add(new CompressedSizeInfoColGroup(idx, nRow, CompressionType.OHE));
+ }
+
// cocode remaining groups
+ if(colInfos.getInfo().size() <= 0 && incompressable.size() <= 0 && emptyCols.size() <= 0 &&
+ constCols.size() == 0 && oheGroups.size() <= 0)
+ throw new DMLCompressionException("empty cocoders 1");
+
if(!groups.isEmpty()) {
colInfos = co.coCodeColumns(colInfos, k);
}
@@ -118,9 +166,13 @@ else if(g.isIncompressable())
colInfos.compressionInfo.add(new CompressedSizeInfoColGroup(idx, nRow, CompressionType.UNCOMPRESSED));
}
+ if(colInfos.getInfo().size() <= 0)
+ throw new DMLCompressionException("empty cocoders 2");
+
return colInfos;
}
+
}
private static boolean containsEmptyConstOrIncompressable(CompressedSizeInfo colInfos) {
@@ -130,6 +182,47 @@ private static boolean containsEmptyConstOrIncompressable(CompressedSizeInfo col
return false;
}
+ private static boolean isCandidate(CompressedSizeInfoColGroup g) {
+		// Check if the column has exactly 2 distinct values other than 0
+ return(g.getNumVals() == 2);
+ }
+
+	private static String isHotEncoded(List<CompressedSizeInfoColGroup> colGroups, boolean isSample, int[] nnzCols,
+ int numRows) {
+ if(colGroups.isEmpty()) {
+ return "NOT_OHE";
+ }
+
+ int numCols = colGroups.size();
+ int totalNumVals = 0;
+ int totalNumOffs = 0;
+
+ for(int i = 0; i < colGroups.size(); i++) {
+ CompressedSizeInfoColGroup g = colGroups.get(i);
+ totalNumVals += g.getNumVals();
+ if(totalNumVals / 2 > numCols)
+ return "NOT_OHE";
+ // If sampling is used, get the number of non-zeroes from the nnzCols array
+ if(isSample && nnzCols != null) {
+ totalNumOffs += nnzCols[i];
+ }
+ else {
+ totalNumOffs += g.getNumOffs();
+ }
+ if(totalNumOffs > numRows) {
+ return "NOT_OHE";
+ }
+ }
+
+ // Check if the current candidates form a valid OHE group
+ if((totalNumVals / 2) == numCols && totalNumOffs == numRows) {
+ return "VALID_OHE";
+ }
+
+ // If still under the row limit, it's potentially OHE
+ return "POTENTIAL_OHE";
+ }
+
private static AColumnCoCoder createColumnGroupPartitioner(PartitionerType type, AComEst est,
ACostEstimate costEstimator, CompressionSettings cs) {
switch(type) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
index a4030d95612..a607424f4c2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
@@ -57,7 +57,7 @@ public abstract class AColGroup implements Serializable {
/** Public super types of compression ColGroups supported */
public static enum CompressionType {
- UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCFOR, DDCFOR, DeltaDDC, LinearFunctional;
+ UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCFOR, DDCFOR, DeltaDDC, LinearFunctional, OHE;
public boolean isDense() {
return this == DDC || this == CONST || this == DDCFOR || this == DDCFOR;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index f382f7b2f71..ec1d6beefa5 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -40,6 +40,7 @@
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.functional.LinearRegression;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
@@ -226,12 +227,12 @@ private void logEstVsActual(double time, AColGroup act, CompressedSizeInfoColGro
if(estC < actC * 0.75) {
String warning = "The estimate cost is significantly off : " + est;
LOG.debug(
- String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", time,
- retType, estC, actC, act.getNumValues(), cols, wanted, warning));
+ String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s",
+ time, retType, estC, actC, act.getNumValues(), cols, wanted, warning));
}
else {
- LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", time,
- retType, estC, actC, act.getNumValues(), cols, wanted));
+ LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s",
+ time, retType, estC, actC, act.getNumValues(), cols, wanted));
}
}
@@ -282,6 +283,11 @@ else if(ct == CompressionType.SDC && colIndexes.size() == 1 && !t) {
return compressSDCSingleColDirectBlock(colIndexes, cg.getNumVals());
}
+ else if(ct == CompressionType.OHE) {
+ // the estimator type signals whether only a sample of the rows was analyzed
+ boolean isSample = ce.getClass().getSimpleName().equals("ComEstSample");
+ return compressOHE(colIndexes, cg, isSample);
+ }
+
final ABitmap ubm = BitmapEncoder.extractBitmap(colIndexes, in, cg.getNumVals(), cs);
if(ubm == null) // no values ... therefore empty
return new ColGroupEmpty(colIndexes);
@@ -312,6 +318,69 @@ else if(ct == CompressionType.SDC && colIndexes.size() == 1 && !t) {
}
}
+ private AColGroup compressOHE(IColIndex colIndexes, CompressedSizeInfoColGroup cg, boolean isSample)
+ throws Exception {
+ // Ensure numVals is valid
+ int numVals = cg.getNumVals();
+ if(numVals <= 0) {
+ throw new DMLCompressionException("Number of values must be greater than 0 for one-hot encoding");
+ }
+
+ AMapToData data;
+
+ if(cs.transposed) {
+ // Handle transposed input: scan each column's sparse row
+ data = MapToFactory.create(in.getNumColumns(), numVals + 1);
+ SparseBlock sb = in.getSparseBlock();
+ data.fill(numVals + 1); // sentinel marking rows that have no code assigned yet
+ for(int c = 0; c < colIndexes.size(); c++) {
+ int cidx = colIndexes.get(c);
+ if(sb.isEmpty(cidx)) {
+ return directCompressDDC(colIndexes, cg);
+ }
+ final int apos = sb.pos(cidx);
+ final int alen = sb.size(cidx) + apos;
+ final int[] aix = sb.indexes(cidx);
+ final double[] aval = sb.values(cidx);
+ for(int k = apos; k < alen; k++) {
+ if(aval[k] != 1 || data.getIndex(aix[k]) != numVals + 1)
+ return directCompressDDC(colIndexes, cg);
+ else
+ data.set(aix[k], c);
+ }
+ }
+ data.setUnique(numVals);
+ }
+ else {
+ // Handle non-transposed matrix
+ data = MapToFactory.create(in.getNumRows(), numVals);
+ for(int r = 0; r < in.getNumRows(); r++) {
+ boolean foundOne = false;
+ for(int c = 0; c < colIndexes.size(); c++) {
+ if(in.get(r, colIndexes.get(c)) == 1) {
+ if(foundOne) {
+ // If another '1' is found in the same row, fall back to directCompressDDC
+ LOG.info("Exact OHE check found a second 1 in the same row; falling back to DDC");
+ return directCompressDDC(colIndexes, cg);
+ }
+ data.set(r, c);
+ foundOne = true;
+ }
+ }
+ if(isSample && !foundOne) {
+ // If it's a sample and no '1' is found in the row, fall back to directCompressDDC
+ LOG.info("Exact OHE check found a sampled row without a 1; falling back to DDC");
+ return directCompressDDC(colIndexes, cg);
+ }
+ }
+ }
+
+ return ColGroupDDC.create(colIndexes, new IdentityDictionary(numVals), data, null);
+ }
+
private AColGroup compressSDCSingleColDirectBlock(IColIndex colIndexes, int nVal) {
final DoubleCountHashMap cMap = new DoubleCountHashMap(nVal);
final int col = colIndexes.get(0);
@@ -482,7 +551,7 @@ private AColGroup directCompressDDCMultiCol(IColIndex colIndexes, CompressedSize
final DblArrayCountHashMap map = new DblArrayCountHashMap(Math.max(cg.getNumVals(), 64));
boolean extra;
- if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k < csi.getNumberColGroups() || pool == null )
+ if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k < csi.getNumberColGroups() || pool == null)
extra = readToMapDDC(colIndexes, map, d, 0, nRow, fill);
else
extra = parallelReadToMapDDC(colIndexes, map, d, nRow, fill, k);
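
The dense (non-transposed) path of `compressOHE` above boils down to an exactly-one-1-per-row scan. A standalone sketch with simplified types (`double[][]` instead of `MatrixBlock`, hypothetical class name), for illustration only:

```java
// Sketch of the dense-path OHE check: a column group is one-hot encoded
// only if every row holds exactly one 1 across the group's columns.
public class OheRowCheck {
    /** Returns the per-row column code, or null when the group is not one-hot. */
    static int[] encode(double[][] m, int[] cols) {
        int[] data = new int[m.length];
        for (int r = 0; r < m.length; r++) {
            boolean foundOne = false;
            for (int c = 0; c < cols.length; c++) {
                if (m[r][cols[c]] == 1) {
                    if (foundOne)
                        return null; // a second 1 in the same row: not OHE
                    data[r] = c;
                    foundOne = true;
                }
            }
            if (!foundOne)
                return null; // row without a 1 (the patch rejects this only for samples)
        }
        return data;
    }
}
```

On success, the code array plays the role of the `AMapToData` that the real method pairs with an `IdentityDictionary`.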
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
index 99bf5cbb8f5..227d6ded3dd 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
@@ -23,6 +23,7 @@
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
@@ -129,4 +130,13 @@ public static long estimateInMemorySizeLinearFunctional(int nrColumns, boolean c
size += 4; // _numRows
return size;
}
+
+ // New method to estimate in-memory size for one-hot encoded columns
+ public static long estimateInMemorySizeOHE(int nrColumns, boolean contiguousColumns, int nrRows) {
+ long size = estimateInMemorySizeGroup(nrColumns, contiguousColumns);
+ // OHE specific estimations
+ size += IdentityDictionary.getInMemorySize(nrColumns); // Dictionary for unique values
+ size += MapToFactory.estimateInMemorySize(nrRows, nrColumns); // Mapping for rows to unique values
+ return size;
+ }
}
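
The size estimate added above is dominated by the mapping structure, since an identity dictionary needs no materialized values. A back-of-envelope sketch; the 1/2/4-byte width rule is an assumption mirroring typical `MapToData` layouts, not the exact SystemDS accounting:

```java
// Rough cost of the row-to-code mapping in an OHE (DDC + IdentityDictionary)
// group: one code per row, with the code width chosen from the distinct count.
public class OheSizeSketch {
    static long mappingBytes(int nRows, int nDistinct) {
        int width = nDistinct <= 256 ? 1 : nDistinct <= 65536 ? 2 : 4;
        return (long) nRows * width;
    }
}
```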
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
index 832725f328f..cf1ce6cc178 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
@@ -71,6 +71,10 @@ public int getNumRows() {
return _cs.transposed ? _data.getNumColumns() : _data.getNumRows();
}
+ public int[] getNnzCols() {
+ return nnzCols;
+ }
+
/**
* Get the number of columns in the overall compressing block.
*
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
index 6306b04e8c1..42fef705e68 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
@@ -160,7 +160,7 @@ private int distinctCountScale(EstimationFactors sampleFacts, int numOffs, int n
// the frequencies of non empty entries.
final int[] freq = sampleFacts.frequencies;
if(freq == null || freq.length == 0)
- return numOffs; // very aggressive number of distinct
+ return numOffs + 1; // very aggressive number of distinct
// sampled size is smaller than actual if there was empty rows.
// and the more we can reduce this value the more accurate the estimation will become.
final int sampledSize = sampleFacts.numOffs;
@@ -171,6 +171,8 @@ private int distinctCountScale(EstimationFactors sampleFacts, int numOffs, int n
if(nCol > 4) // Increase estimate if we get into many columns cocoding to be safe
est += ((double) est) * ((double) nCol) / 10;
// Bound the estimate with the maxDistinct.
+ if(sampleFacts.numRows > sampleFacts.numOffs)
+ est += 1; // empty rows in the sample: count the zero tuple as one more distinct
return Math.max(Math.min(est, Math.min(maxDistinct, numOffs)), 1);
}
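
The `est += 1` correction above can be shown in isolation. A hedged sketch with a simplified signature (the real logic lives inside `distinctCountScale` and reads `EstimationFactors` fields):

```java
// If the sample contains empty rows (numRows > numOffs), the all-zero tuple
// is a distinct value the frequency table never counts, so the estimate is
// bumped by one before the final bounding against maxDistinct and numOffs.
public class DistinctAdjust {
    static int adjust(int est, int numRows, int numOffs, int maxDistinct) {
        if (numRows > numOffs)
            est += 1; // account for the uncounted zero tuple
        return Math.max(Math.min(est, Math.min(maxDistinct, numOffs)), 1);
    }
}
```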
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfo.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfo.java
index fe87057f67a..ed2018a7b95 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfo.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfo.java
@@ -25,6 +25,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
/**
* A helper reusable object for maintaining information about estimated compression
@@ -95,6 +96,9 @@ public String getEstimatedDistinct() {
StringBuilder sb = new StringBuilder();
if(compressionInfo == null)
return "";
+ if(compressionInfo.isEmpty())
+ throw new DMLCompressionException("Compression info is empty");
+
sb.append("[");
sb.append(compressionInfo.get(0).getNumVals());
for(int i = 1; i < compressionInfo.size(); i++)
@@ -107,6 +111,8 @@ public String getNrColumnsString() {
StringBuilder sb = new StringBuilder();
if(compressionInfo == null)
return "";
+ if(compressionInfo.isEmpty())
+ return "";
sb.append("[");
sb.append(compressionInfo.get(0).getColumns().size());
for(int i = 1; i < compressionInfo.size(); i++)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
index 4fbf9b0ee4d..0a9ef064040 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
@@ -119,20 +119,26 @@ public CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts,
*/
public CompressedSizeInfoColGroup(IColIndex columns, int nRows, CompressionType ct) {
_cols = columns;
- _facts = new EstimationFactors(0, nRows);
_sizes = new EnumMap<>(CompressionType.class);
switch(ct) {
case EMPTY:
+ _facts = new EstimationFactors(1, nRows);
_sizes.put(ct, (double) ColGroupSizes.estimateInMemorySizeEMPTY(columns.size(), columns.isContiguous()));
break;
case CONST:
+ _facts = new EstimationFactors(1, nRows);
_sizes.put(ct,
(double) ColGroupSizes.estimateInMemorySizeCONST(columns.size(), columns.isContiguous(), 1.0, false));
break;
case UNCOMPRESSED:
+ _facts = new EstimationFactors(nRows, nRows);
_sizes.put(ct, (double) ColGroupSizes.estimateInMemorySizeUncompressed(nRows, columns.isContiguous(),
columns.size(), 1.0));
break;
+ case OHE:
+ _facts = new EstimationFactors(columns.size(), nRows);
+ _sizes.put(ct, (double) ColGroupSizes.estimateInMemorySizeOHE(columns.size(), columns.isContiguous(), nRows));
+ break;
default:
throw new DMLCompressionException("Invalid instantiation of const Cost");
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
index ffe365127af..69ea2bde175 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
@@ -391,10 +391,10 @@ public EstimationFactors extractFacts(int nRows, double tupleSparsity, double ma
final int[] counts = map.getCounts();
if(cs.isRLEAllowed())
- return new EstimationFactors(map.getUnique(), map.size(), largestOffs, counts, 0, nRows, map.countRuns(off),
+ return new EstimationFactors(getUnique(), map.size(), largestOffs, counts, 0, nRows, map.countRuns(off),
false, true, matrixSparsity, tupleSparsity);
else
- return new EstimationFactors(map.getUnique(), map.size(), largestOffs, counts, 0, nRows, false, true,
+ return new EstimationFactors(getUnique(), map.size(), largestOffs, counts, 0, nRows, false, true,
matrixSparsity, tupleSparsity);
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index c514e010a72..fb364aecfbd 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -630,6 +630,7 @@ private void outputMatrixPostProcessing(MatrixBlock output, int k){
}
else {
output.recomputeNonZeros(k);
+ output.examSparsity(k);
}