Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade ChiSqSelectorModel to Spark 3.2.0-compatible design #815

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@ import scala.collection.mutable
/**
* Created by hollinwilkins on 12/27/16.
*/
@SparkCode(uri = "https://github.com/apache/spark/blob/v2.0.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala")
@SparkCode(uri = "https://github.com/apache/spark/blob/v3.2.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala")
case class ChiSqSelectorModel(filterIndices: Seq[Int],
inputSize: Int) extends Model {
private val sortedFilterIndices = filterIndices.sorted
def apply(features: Vector): Vector = {
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
val newSize = sortedFilterIndices.length
val newValues = mutable.ArrayBuilder.make[Double]
val newIndices = mutable.ArrayBuilder.make[Int]
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
while (i < indices.length && j < filterIndices.length) {
while (i < indices.length && j < sortedFilterIndices.length) {
indicesIdx = indices(i)
filterIndicesIdx = filterIndices(j)
filterIndicesIdx = sortedFilterIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
Expand All @@ -43,7 +44,7 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
Vectors.dense(filterIndices.map(i => values(i)).toArray)
Vectors.dense(sortedFilterIndices.map(i => values(i)).toArray)
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
Expand All @@ -52,5 +53,5 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],

override def inputSchema: StructType = StructType("input" -> TensorType.Double(inputSize)).get

override def outputSchema: StructType = StructType("output" -> TensorType.Double(filterIndices.length)).get
override def outputSchema: StructType = StructType("output" -> TensorType.Double(sortedFilterIndices.length)).get
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@ package ml.combust.mleap.core.feature

import ml.combust.mleap.core.types.{StructField, TensorType}
import org.scalatest.FunSpec
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

class ChiSqSelectorModelSpec extends FunSpec {

describe("input/output schema"){
val model = new ChiSqSelectorModel(Seq(1,2,3), 3)
val model = new ChiSqSelectorModel(Seq(2,3, 1), 3)

it("Dense vector work with unsorted indices") {
val vector = Vectors.dense(1.0,2.0,3.0,4.0)
assert(model(vector) == Vectors.dense(2.0, 3.0, 4.0))
}

it("Sparse vector work with unsorted indices") {
val vector = Vectors.sparse(size = 4, indices=Array(0,1,2,3), values = Array(1.0,2.0,3.0,4.0))
assert(model(vector) == Vectors.sparse(size=3, indices=Array(0,1,2), values=Array(2.0,3.0,4.0)))
}

it("Has the right input schema") {
assert(model.inputSchema.fields ==
Expand Down