Beginner's Guide: Apache Spark Machine Learning with Large Data

This informative tutorial walks us through using Spark's machine learning capabilities and Scala to train a logistic regression classifier on a larger-than-memory dataset.

6. Preparing training and testing datasets

The next step is to create binary labels for a binary classifier. In these code examples, “java” is the label we would like the classifier to predict. All rows with the “java” tag should be marked as “1” and all rows without it as “0”. Let’s define our target tag “java” and create binary labels based on it.

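import org.apache.spark.sql.functions.{col, udf}

val targetTag = "java"
// 1.0 if the Tags string contains the target tag, 0.0 otherwise
val myudf: (String => Double) = (str: String) => 
    {if (str.contains(targetTag)) 1.0 else 0.0}
val sqlfunc = udf(myudf)
val postsLabeled = postsDf.withColumn("Label", 
    sqlfunc(col("Tags")) )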
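
One thing to watch with the substring check above: `contains` matches any tag that includes the text "java", so a post tagged only "javascript" would also receive the label 1.0. The sketch below, in plain Scala without Spark, contrasts the substring check with a whole-tag comparison. It assumes that the Tags value holds tag names separated by whitespace, which may not match the actual column format; `LabelSketch`, `substringLabel` and `exactTagLabel` are illustrative names, not part of the original code.

```scala
object LabelSketch {
  val targetTag = "java"

  // Same rule as the UDF in the article: 1.0 if "java" occurs anywhere in the string.
  def substringLabel(tags: String): Double =
    if (tags.contains(targetTag)) 1.0 else 0.0

  // Hypothetical stricter rule: 1.0 only if one whole tag equals "java".
  // Assumes tags are separated by whitespace.
  def exactTagLabel(tags: String): Double =
    if (tags.trim.split("\\s+").contains(targetTag)) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    for (tags <- Seq("java jar", "javascript html", "python")) {
      println(s"$tags -> substring=${substringLabel(tags)}, exact=${exactTagLabel(tags)}")
    }
  }
}
```

Whether the stricter rule is preferable depends on whether related tags should count as positive examples.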

The dataset can be split into positive and negative subsets by using the new label.

val positive = postsLabeled.filter('Label > 0.0)
val negative = postsLabeled.filter('Label < 1.0)

We are going to use 90% of our data for the model training and 10% as a testing dataset. Let’s create a training dataset by sampling the positive and negative datasets separately.

val positiveTrain = positive.sample(false, 0.9)
val negativeTrain = negative.sample(false, 0.9)
val training = positiveTrain.unionAll(negativeTrain)

The testing dataset should include all rows that are not included in the training dataset. Again, positive and negative examples are handled separately.

val negativeTrainTmp = negativeTrain
    .withColumnRenamed("Label", "Flag").select('Id, 'Flag)

val negativeTest = negative.join( negativeTrainTmp, 
    negative("Id") === negativeTrainTmp("Id"), 
    "LeftOuter").filter("Flag is null")
    .select(negative("Id"), 'Tags, 'Text, 'Label)

val positiveTrainTmp = positiveTrain
    .withColumnRenamed("Label", "Flag")
    .select('Id, 'Flag)

val positiveTest = positive.join( positiveTrainTmp, 
    positive("Id") === positiveTrainTmp("Id"), 
    "LeftOuter").filter("Flag is null")
    .select(positive("Id"), 'Tags, 'Text, 'Label)

val testing = negativeTest.unionAll(positiveTest)
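
The sample-then-exclude pattern above can be sketched on an in-memory collection to check its key property: every row ends up in exactly one of the two sets. This is plain Scala rather than Spark; the `Post` case class, the `split` helper and the fixed seed are illustrative assumptions. Like `sample(false, 0.9)`, the sketch keeps each row independently with probability 0.9, so the training share is approximately, not exactly, 90%.

```scala
import scala.util.Random

object SplitSketch {
  case class Post(id: Int, label: Double)

  // Keep each row with probability `fraction` (Bernoulli sampling),
  // then put every row whose id was not sampled into the test set.
  def split(rows: Seq[Post], fraction: Double, rng: Random): (Seq[Post], Seq[Post]) = {
    val train = rows.filter(_ => rng.nextDouble() < fraction)
    val trainIds = train.map(_.id).toSet
    val test = rows.filterNot(p => trainIds.contains(p.id))
    (train, test)
  }

  def main(args: Array[String]): Unit = {
    val rng = new Random(42) // fixed seed, chosen arbitrarily
    val posts = (1 to 1000).map(i => Post(i, if (i % 10 == 0) 1.0 else 0.0))
    val (positive, negative) = posts.partition(_.label > 0.0)

    // Split each class separately, as the article does, then combine.
    val (posTrain, posTest) = split(positive, 0.9, rng)
    val (negTrain, negTest) = split(negative, 0.9, rng)
    val training = posTrain ++ negTrain
    val testing = posTest ++ negTest

    println(s"training=${training.size}, testing=${testing.size}")
  }
}
```

Splitting each class separately keeps the ratio of positive to negative rows roughly the same in both sets.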

7. Training a model

Let’s identify training parameters:

  1. Number of features
  2. Regularization parameter
  3. Number of epochs for gradient descent

The Spark API creates a model based on columns from the data frame and the training parameters:

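import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

val numFeatures = 64000
val numEpochs = 30
val regParam = 0.02

val tokenizer = new Tokenizer().setInputCol("Text")
    .setOutputCol("Words")

// The source listing is cut off after "new". The HashingTF setup
// below is a reconstruction; "Words" and "Features" are assumed names.
val hashingTF = new HashingTF().setNumFeatures(numFeatures)
    .setInputCol(tokenizer.getOutputCol).setOutputCol("Features")

val lr = new LogisticRegression().setMaxIter(numEpochs)
    .setRegParam(regParam).setFeaturesCol("Features")
    .setLabelCol("Label").setPredictionCol("Prediction")

val pipeline = new Pipeline()
    .setStages(Array(tokenizer, hashingTF, lr))

// Also cut off in the source; fitting the pipeline on the
// training dataset is the assumed completion.
val model = pipeline.fit(training)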
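
To make the role of the number of features more concrete, here is a minimal plain-Scala sketch of the hashing trick that a hashing term-frequency transformer relies on: each token is hashed to an index in the range [0, numFeatures), and the counts at that index are accumulated. It is not Spark's implementation (Spark's hash function and tokenizer details differ); `HashingSketch`, `tokenize`, `bucket` and `termFrequencies` are illustrative names.

```scala
object HashingSketch {
  // Lowercase and split on whitespace, roughly what a simple tokenizer does.
  def tokenize(text: String): Seq[String] =
    text.toLowerCase.split("\\s+").filter(_.nonEmpty).toSeq

  // Map a token to a bucket in [0, numFeatures); the extra "+ numFeatures"
  // step keeps the index non-negative when hashCode is negative.
  def bucket(token: String, numFeatures: Int): Int =
    ((token.hashCode % numFeatures) + numFeatures) % numFeatures

  // Sparse term-frequency vector: bucket index -> number of tokens hashed there.
  def termFrequencies(text: String, numFeatures: Int): Map[Int, Double] =
    tokenize(text)
      .groupBy(bucket(_, numFeatures))
      .map { case (index, tokens) => index -> tokens.size.toDouble }

  def main(args: Array[String]): Unit = {
    val tf = termFrequencies("merge JAR files into one JAR file", 64000)
    println(tf)
  }
}
```

A smaller number of features saves memory but makes it more likely that unrelated words collide in the same bucket.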

8. Testing a model

This is our final code for the binary “Java” classifier which returns a prediction (0.0 or 1.0):

val testTitle = 
 "Easiest way to merge a release into one JAR file"

val testBody = 
 """Is there a tool or script which easily merges a bunch 
 of JAR files into one JAR file? A bonus would be to 
 easily set the main-file manifest and make it executable.
 I would like to run it with something like: As far as I 
 can tell, it has no dependencies which indicates that it 
 shouldn't be an easy single-file tool, but the downloaded
 ZIP file contains a lot of libraries."""

val testText = testTitle + testBody

val testDF = sqlContext
   .createDataFrame(Seq( (99.0, testText)))
   .toDF("Label", "Text")

val result = model.transform(testDF)

val prediction = result.collect()(0)(6)

print("Prediction: "+ prediction)

Let’s evaluate the quality of the model based on the testing dataset.

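import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val testingResult = model.transform(testing)

// Pair each prediction with its true label as (Double, Double)
val testingResultScores = testingResult
   .select("Prediction", "Label").rdd
   .map(r => (r(0).asInstanceOf[Double], 
       r(1).asInstanceOf[Double]))

val bc = 
   new BinaryClassificationMetrics(testingResultScores)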

val roc = bc.areaUnderROC

print("Area under the ROC:" + roc)

If you use the small dataset, the quality of your model will probably not be the best: the area under the ROC curve will be very low (close to 50%), which indicates a poor model. With the entire Posts.xml dataset, the quality is not so bad: the area under the ROC curve is 0.64. You can probably improve this result by experimenting with different transformations such as TF-IDF and normalization, but that is outside the scope of this blog post.
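
For intuition about why a value near 0.5 signals a weak model, the area under the ROC curve can be read as the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counted as one half. The plain-Scala sketch below computes it by brute force over all positive/negative pairs. It is a small illustration for tiny inputs, not how `BinaryClassificationMetrics` computes the value; `AucSketch` and `areaUnderRoc` are illustrative names.

```scala
object AucSketch {
  // Input pairs are (score, label), with label 1.0 for positive and 0.0 for negative.
  def areaUnderRoc(scoreAndLabels: Seq[(Double, Double)]): Double = {
    val positives = scoreAndLabels.collect { case (s, l) if l > 0.5 => s }
    val negatives = scoreAndLabels.collect { case (s, l) if l <= 0.5 => s }
    require(positives.nonEmpty && negatives.nonEmpty, "need both classes")

    // Count pairs where the positive outranks the negative; a tie counts as half.
    val wins = for (p <- positives; n <- negatives)
      yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0

    wins.sum / (positives.size.toLong * negatives.size)
  }

  def main(args: Array[String]): Unit = {
    val perfect = Seq((1.0, 1.0), (1.0, 1.0), (0.0, 0.0), (0.0, 0.0))
    val constant = Seq((1.0, 1.0), (1.0, 0.0), (1.0, 1.0), (1.0, 0.0))
    println(areaUnderRoc(perfect))  // 1.0
    println(areaUnderRoc(constant)) // 0.5
  }
}
```

A model that gives every row the same score cannot rank positives above negatives, so it lands at exactly 0.5, the same as random guessing.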


Apache Spark can be a great option for data processing and machine learning scenarios if your dataset is larger than your computer's memory can hold. It might not be easy to use Spark in cluster mode within the Hadoop YARN environment. However, in local (or standalone) mode, Spark is as simple as any other analytical tool.

Please let me know if you encounter any problems or have further questions. I would really like to hear your feedback.

Bio: Dmitry Petrov, Ph.D. is a Data Scientist at Microsoft. He previously was a Researcher at a university.