
KZKY memo

Personal notes.

Trying out SVM in Spark MLlib

MLlib's SVM

MLlib's SVM appears to come in two variants:

  • SVM (L2-regularizer)
  • SVM (L1-regularizer)

However, the loss seems to be the L1 hinge loss only.
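Written out, the two objectives should look roughly like this (my sketch of the standard formulations with regularization parameter λ, not copied from the MLlib docs):

```latex
% L2-regularized hinge loss
\min_{w}\; \frac{1}{n}\sum_{i=1}^{n} \max\{0,\; 1 - y_i\, w^\top x_i\} \;+\; \frac{\lambda}{2}\,\|w\|_2^2

% L1-regularized hinge loss
\min_{w}\; \frac{1}{n}\sum_{i=1}^{n} \max\{0,\; 1 - y_i\, w^\top x_i\} \;+\; \lambda\,\|w\|_1
```

Here the labels y_i are taken in {-1, +1} for the math, even though MLlib's input format uses {0, 1}.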

Sample Data Retrieval

$ git clone https://github.com/apache/incubator-spark.git
$ cd incubator-spark/data

The sample data is in that directory.
The binary-classification format seems to expect labels in {0, 1},
so convert the -1 labels in lr_data.txt, which uses {-1, 1}, to 0.
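A tiny Scala helper for that label rewrite might look like the following (a hypothetical helper of my own, not part of Spark; it assumes the label is the first space-separated token on each line):

```scala
object LabelFix {
    // Rewrite the leading label token from -1 to 0; leave the features untouched.
    def relabel(line: String): String = {
        val parts = line.split(' ')
        val label = if (parts(0).toDouble == -1.0) "0" else parts(0)
        (label +: parts.tail).mkString(" ")
    }

    def main(args: Array[String]): Unit = {
        println(relabel("-1 2.0 3.0"))  // label flipped to 0
        println(relabel("1 0.5 0.5"))   // positive label left as-is
    }
}
```

In a Spark job the same transformation could be applied with `sc.textFile(file).map(LabelFix.relabel)` before parsing.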

SVMSample (L2-norm)

package edu.kzk.spark_sample.mllib

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint

object SVMSample {

    def main(args: Array[String]) {

        val sc = new SparkContext("local", "SVM sample",
                "/opt/cloudera/parcels/SPARK/");

        // Load and parse the data file
        val file = "/home/kzk/datasets/spark-sample/svm_data.txt";
        val data = sc.textFile(file);
        val parsedData = data.map { line =>
            val parts = line.split(' ')
            LabeledPoint(parts(0).toDouble, parts.tail.map(x => x.toDouble).toArray)
        }

        // Run training algorithm to build the model
        val numIterations = 20;
        val model = SVMWithSGD.train(parsedData, numIterations);

        // Evaluate model on training examples and compute training error
        val labelAndPreds = parsedData.map { point =>
            val prediction = model.predict(point.features)
            (point.label, prediction)
        }
        val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count;
        println("Training Error = " + trainErr)

    }
}

SVML1Sample (L1-norm)

package edu.kzk.spark_sample.mllib

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.optimization.L1Updater

object SVML1Sample {
    def main(args: Array[String]) {

        val sc = new SparkContext("local", "SVM sample",
                "/opt/cloudera/parcels/SPARK/");

        // Load and parse the data file
        val file = "/home/kzk/datasets/spark-sample/svm_data.txt";
        val data = sc.textFile(file);
        val parsedData = data.map { line =>
            val parts = line.split(' ')
            LabeledPoint(parts(0).toDouble, parts.tail.map(x => x.toDouble).toArray)
        }

        // Run training algorithm to build the model
        // with L1-norm
        val numIterations = 20;
        val svmAlg = new SVMWithSGD();
        svmAlg.optimizer
            .setNumIterations(numIterations)
            .setRegParam(0.1)
            .setUpdater(new L1Updater);
        
        val model = svmAlg.run(parsedData);

        // Evaluate model on training examples and compute training error
        val labelAndPreds = parsedData.map { point =>
            val prediction = model.predict(point.features)
            (point.label, prediction)
        }
        val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count;
        println("Training Error = " + trainErr)
    }
}

Algorithm

  • When told it is solved with SGD, the only online method I know is Passive Aggressive (PA).
  • Judging from HingeGradient#compute, it does look like Passive Aggressive, but the step size (= stepSize/sqrt(iter)) is chosen by hand, so it is SGD.
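To make the step-size point concrete, here is an illustrative pure-Scala sketch (names are mine, not MLlib's) of a hinge subgradient plus the decaying step stepSize/sqrt(iter); it assumes labels scaled to {-1, +1}:

```scala
object HingeSgdSketch {
    // Subgradient of the hinge loss max{0, 1 - y * w.x} at w,
    // for a single example (x, y) with y in {-1, +1}.
    def hingeSubgradient(w: Array[Double], x: Array[Double], y: Double): Array[Double] = {
        val margin = y * w.zip(x).map { case (wi, xi) => wi * xi }.sum
        if (margin < 1.0) x.map(xi => -y * xi)         // misclassified or inside margin
        else Array.fill(w.length)(0.0)                 // correctly classified: zero subgradient
    }

    // One SGD update with the decaying step size stepSize / sqrt(iter).
    def step(w: Array[Double], g: Array[Double], stepSize: Double, iter: Int): Array[Double] =
        w.zip(g).map { case (wi, gi) => wi - (stepSize / math.sqrt(iter)) * gi }
}
```

Unlike PA, which solves for an exact step length per example, this fixed decaying schedule is what makes it plain SGD.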