From 5ce228d1f3f241ad0b048e6d697e16314ef68c66 Mon Sep 17 00:00:00 2001 From: Yunhui Zhang Date: Wed, 27 Apr 2022 09:50:45 -0400 Subject: [PATCH] Upgrade ChiSqSelectorModel to spark 3.2.0 compatible design sort filterIndices before using it --- .../mleap/core/feature/ChiSqSelectorModel.scala | 13 +++++++------ .../mleap/core/feature/ChiSqSelectorModelSpec.scala | 13 ++++++++++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/ChiSqSelectorModel.scala b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/ChiSqSelectorModel.scala index 20ac8dd5a..fa5709428 100644 --- a/mleap-core/src/main/scala/ml/combust/mleap/core/feature/ChiSqSelectorModel.scala +++ b/mleap-core/src/main/scala/ml/combust/mleap/core/feature/ChiSqSelectorModel.scala @@ -10,22 +10,23 @@ import scala.collection.mutable /** * Created by hollinwilkins on 12/27/16. */ -@SparkCode(uri = "https://github.com/apache/spark/blob/v2.0.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala") +@SparkCode(uri = "https://github.com/apache/spark/blob/v3.2.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala") case class ChiSqSelectorModel(filterIndices: Seq[Int], inputSize: Int) extends Model { + private val sortedFilterIndices = filterIndices.sorted def apply(features: Vector): Vector = { features match { case SparseVector(size, indices, values) => - val newSize = filterIndices.length + val newSize = sortedFilterIndices.length val newValues = mutable.ArrayBuilder.make[Double] val newIndices = mutable.ArrayBuilder.make[Int] var i = 0 var j = 0 var indicesIdx = 0 var filterIndicesIdx = 0 - while (i < indices.length && j < filterIndices.length) { + while (i < indices.length && j < sortedFilterIndices.length) { indicesIdx = indices(i) - filterIndicesIdx = filterIndices(j) + filterIndicesIdx = sortedFilterIndices(j) if (indicesIdx == filterIndicesIdx) { newIndices += j newValues += values(i) @@ -43,7 +44,7 
@@ case class ChiSqSelectorModel(filterIndices: Seq[Int], Vectors.sparse(newSize, newIndices.result(), newValues.result()) case DenseVector(values) => val values = features.toArray - Vectors.dense(filterIndices.map(i => values(i)).toArray) + Vectors.dense(sortedFilterIndices.map(i => values(i)).toArray) case other => throw new UnsupportedOperationException( s"Only sparse and dense vectors are supported but got ${other.getClass}.") @@ -52,5 +53,5 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int], override def inputSchema: StructType = StructType("input" -> TensorType.Double(inputSize)).get - override def outputSchema: StructType = StructType("output" -> TensorType.Double(filterIndices.length)).get + override def outputSchema: StructType = StructType("output" -> TensorType.Double(sortedFilterIndices.length)).get } diff --git a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/ChiSqSelectorModelSpec.scala b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/ChiSqSelectorModelSpec.scala index 9ee6234a1..bce04efdc 100644 --- a/mleap-core/src/test/scala/ml/combust/mleap/core/feature/ChiSqSelectorModelSpec.scala +++ b/mleap-core/src/test/scala/ml/combust/mleap/core/feature/ChiSqSelectorModelSpec.scala @@ -2,11 +2,22 @@ package ml.combust.mleap.core.feature import ml.combust.mleap.core.types.{StructField, TensorType} import org.scalatest.FunSpec +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} class ChiSqSelectorModelSpec extends FunSpec { describe("input/output schema"){ - val model = new ChiSqSelectorModel(Seq(1,2,3), 3) + val model = new ChiSqSelectorModel(Seq(2,3, 1), 3) + + it("Dense vector work with unsorted indices") { + val vector = Vectors.dense(1.0,2.0,3.0,4.0) + assert(model(vector) == Vectors.dense(2.0, 3.0, 4.0)) + } + + it("Sparse vector work with unsorted indices") { + val vector = Vectors.sparse(size = 4, indices=Array(0,1,2,3), values = Array(1.0,2.0,3.0,4.0)) + assert(model(vector) == Vectors.sparse(size=3, 
indices=Array(0,1,2), values=Array(2.0,3.0,4.0))) + } it("Has the right input schema") { assert(model.inputSchema.fields ==