
"스레드"main "예외"가 발생했습니다. java.lang.ClassCastException : org.apache.spark.ml.attribute.UnresolvedAttribute $를 org.apache로 형변환 할 수 없습니다. spark.ml.attribute.NominalAttribute ".ClassCastException을 가져 오는 데이터 집합에 GBT를 적용하려고 시도합니다.

Source code (I am using Spark with Java):

package com.spark.lograthmicregression; 

import java.text.ParseException; 
import java.text.SimpleDateFormat; 
import java.util.Calendar; 
import java.util.Date; 
import java.util.HashSet; 
import java.util.Set; 

import org.apache.spark.SparkConf; 
import org.apache.spark.api.java.JavaSparkContext; 
import org.apache.spark.ml.Pipeline; 
import org.apache.spark.ml.PipelineModel; 
import org.apache.spark.ml.PipelineStage; 
import org.apache.spark.ml.classification.GBTClassificationModel; 
import org.apache.spark.ml.classification.GBTClassifier; 
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; 
import org.apache.spark.ml.feature.IndexToString; 
import org.apache.spark.ml.feature.StringIndexer; 
import org.apache.spark.ml.feature.StringIndexerModel; 
import org.apache.spark.ml.feature.VectorAssembler; 
import org.apache.spark.sql.DataFrame; 
import org.apache.spark.sql.SQLContext; 
import org.apache.spark.sql.catalyst.expressions.AttributeReference; 
import org.apache.spark.sql.catalyst.expressions.Expression; 
import org.apache.spark.sql.types.DataType; 
import org.apache.spark.sql.types.DataTypes; 

import com.google.common.collect.ImmutableMap; 

import scala.collection.mutable.Seq; 

public class ClickThroughRateAnalytics { 

    private static SimpleDateFormat sdf = new SimpleDateFormat("yyMMddHH"); 

    public static void main(String[] args) { 

     final SparkConf sparkConf = new SparkConf().setAppName("Click Analysis").setMaster("local"); 

     try (JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf)) { 

      SQLContext sqlContext = new SQLContext(javaSparkContext); 
      DataFrame dataFrame = sqlContext.read().format("com.databricks.spark.csv").option("inferSchema", "true").option("header", "true") 
        .load("/splits/sub-suaa"); 

      // This will keep data in memory 
      dataFrame.cache(); 

      // This will describe the column 
      // dataFrame.describe("hour").show(); 

      System.out.println("Rows before removing missing data : " + dataFrame.count()); 

      // This will describe column details 
      // dataFrame.describe("click", "hour", "site_domain").show(); 

      // This will calculate variance between columns +ve one increases 
      // second increases and -ve means one increases other decreases 
      // double cov = dataFrame.stat().cov("click", "hour"); 
      // System.out.println("cov : " + cov); 

      // It provides quantitative measurements of the statistical 
      // dependence between two random variables 
      // double corr = dataFrame.stat().corr("click", "hour"); 
      // System.out.println("corr : " + corr); 

      // Cross Tabulation provides a table of the frequency distribution 
      // for a set of variables 
      // dataFrame.stat().crosstab("site_id", "site_domain").show(); 

      // For frequent items 
      // System.out.println("Frequest Items : " + 
      // dataFrame.stat().freqItems(new String[] { "site_id", 
      // "site_domain" }, 0.3).collectAsList()); 

      // TODO we can also set maximum occurring item to categorical 
      // values. 

      // This will replace null values with average for numeric columns 
      dataFrame = modifiyDatFrame(dataFrame); 

      // Removing rows which have some missing values 
      dataFrame = dataFrame.na().replace(dataFrame.columns(), ImmutableMap.of("", "NA")); 
      dataFrame.na().fill(0.0); 
      dataFrame = dataFrame.na().drop(); 

      System.out.println("Rows after removing missing data : " + dataFrame.count()); 

      // TODO Binning and bucketing 

      // normalizer will take the column created by the VectorAssembler, 
      // normalize it and produce a new column 
      // Normalizer normalizer = new 
      // Normalizer().setInputCol("features_index").setOutputCol("features"); 

      dataFrame = dataFrame.drop("app_category_index").drop("app_domain_index").drop("hour_index").drop("C20_index") 
        .drop("device_connection_type_index").drop("C1_index").drop("id").drop("device_ip_index").drop("banner_pos_index"); 
      DataFrame[] splits = dataFrame.randomSplit(new double[] { 0.7, 0.3 }); 
      DataFrame trainingData = splits[0]; 
      DataFrame testData = splits[1]; 

      StringIndexerModel labelIndexer = new StringIndexer().setInputCol("click").setOutputCol("indexedclick").fit(dataFrame); 
      // Here we will be sending all columns which will participate in 
      // prediction 
      VectorAssembler vectorAssembler = new VectorAssembler().setInputCols(findPredictionColumns("click", dataFrame)) 
        .setOutputCol("features_index"); 

      GBTClassifier gbt = new GBTClassifier().setLabelCol("indexedclick").setFeaturesCol("features_index").setMaxIter(10).setMaxBins(69000); 

      IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel"); 
      Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { labelIndexer, vectorAssembler, gbt, labelConverter }); 

      trainingData.show(1); 
      PipelineModel model = pipeline.fit(trainingData); 
      DataFrame predictions = model.transform(testData); 
      predictions.select("predictedLabel", "label").show(5); 
      MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel") 
        .setPredictionCol("prediction").setMetricName("precision"); 
      double accuracy = evaluator.evaluate(predictions); 
      System.out.println("Test Error = " + (1.0 - accuracy)); 

      GBTClassificationModel gbtModel = (GBTClassificationModel) (model.stages()[2]); 

      System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); 

     } 
    } 

    private static String[] findPredictionColumns(String outputCol, DataFrame dataFrame) { 
     String columns[] = dataFrame.columns(); 
     String inputColumns[] = new String[columns.length - 1]; 
     int count = 0; 
     for (String column : dataFrame.columns()) { 
      if (!column.equalsIgnoreCase(outputCol)) { 
       inputColumns[count++] = column; 
      } 
     } 
     return inputColumns; 
    } 

    /** 
    * Replaces empty values in numeric columns with the column mean and 
    * string-indexes the remaining categorical columns. 
    * 
    * @param dataFrame the frame to clean up 
    * @return the modified DataFrame 
    */ 
    private static DataFrame modifiyDatFrame(DataFrame dataFrame) { 
     Set<String> numericColumns = new HashSet<String>(); 
     if (dataFrame.numericColumns() != null && dataFrame.numericColumns().length() > 0) { 
      scala.collection.Iterator<Expression> iterator = ((Seq<Expression>) dataFrame.numericColumns()).toIterator(); 
      while (iterator.hasNext()) { 
       Expression expression = iterator.next(); 
       Double avgAge = dataFrame.na().drop().groupBy(((AttributeReference) expression).name()).avg(((AttributeReference) expression).name()) 
         .first().getDouble(1); 
       dataFrame = dataFrame.na().fill(avgAge, new String[] { ((AttributeReference) expression).name() }); 
       numericColumns.add(((AttributeReference) expression).name()); 

       DataType dataType = ((AttributeReference) expression).dataType(); 
       if (!"double".equalsIgnoreCase(dataType.simpleString())) { 
        dataFrame = dataFrame.withColumn("temp", dataFrame.col(((AttributeReference) expression).name()).cast(DataTypes.DoubleType)) 
          .drop(((AttributeReference) expression).name()).withColumnRenamed("temp", ((AttributeReference) expression).name()); 
       } 
      } 
     } 

     // Fit method of StringIndexer converts the column to StringType(if 
     // it is not of StringType) and then counts the occurrence of each 
     // word. It then sorts these words in descending order of their 
     // frequency and assigns an index to each word. StringIndexer.fit() 
     // method returns a StringIndexerModel which is a Transformer 
     StringIndexer stringIndexer = new StringIndexer(); 
     String allCoumns[] = dataFrame.columns(); 
     for (String column : allCoumns) { 
      if (!numericColumns.contains(column)) { 
       dataFrame = stringIndexer.setInputCol(column).setOutputCol(column + "_index").fit(dataFrame).transform(dataFrame); 
       dataFrame = dataFrame.drop(column); 
      } 
     } 

     dataFrame.printSchema(); 
     return dataFrame; 
    } 

    @SuppressWarnings("unused") 
    private static void copyFile(DataFrame dataFrame) { 
     dataFrame 
       .select("id", "click", "hour", "C1", "banner_pos", "site_id", "site_domain", "site_category", "app_id", "app_domain", "app_category", 
         "device_id", "device_ip", "device_model", "device_type", "device_conn_type", "C14", "C15", "C16", "C17", "C18", "C19", "C20", 
         "C21") 
       .write().format("com.databricks.spark.csv").option("header", "true").option("codec", "org.apache.hadoop.io.compress.GzipCodec") 
       .save("/splits/sub-splitaa-optmized"); 
    } 

    @SuppressWarnings("unused") 
    private static Integer parse(String sDate, int field) { 
     try { 
      if (sDate != null && !sDate.toString().equalsIgnoreCase("hour")) { 
       Date date = sdf.parse(sDate.toString()); 
       Calendar cal = Calendar.getInstance(); 
       cal.setTime(date); 
       return cal.get(field); 
      } 
     } catch (ParseException e) { 
      e.printStackTrace(); 
     } 
     return 0; 
    } 

} 

The sample file looks like this:

id,click,hour,C1,banner_pos,site_id,site_domain,site_category,app_id,app_domain,app_category,device_id,device_ip,device_model,device_type,device_conn_type,C14,C15,C16,C17,C18,C19,C20,C21
1000009418151094273,0,14102100,1005,0,1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,ddd2926e,44956a24,1,2,15706,320,50,1722,0,35,-1,79
10000169349117863715,0,14102100,1005,0,1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,96809ac8,711ee120,1,0,15704,320,50,1722,0,35,100084,79
(further rows follow in the same comma-separated format)


Were you able to resolve this? I am looking at something similar for NaiveBayes; there is a pending PR against Spark, but no activity on it. – Bob


Yes, you have to set maxBins = (the maximum number of distinct categories over all columns), but after doing that I get a stack overflow exception. I may not have enough memory. – cody123
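
For reference, a minimal sketch (reusing the question's DataFrame and the Spark 1.x DataFrame API; not tuned for a dataset of this size) of deriving that number and passing it to setMaxBins:

// Find the largest number of distinct values over all columns.
// Note: this runs one distinct().count() job per column, which is expensive.
long maxCategories = 0;
for (String column : dataFrame.columns()) {
    long distinct = dataFrame.select(column).distinct().count();
    if (distinct > maxCategories) {
        maxCategories = distinct;
    }
}

GBTClassifier gbt = new GBTClassifier()
    .setLabelCol("indexedclick")
    .setFeaturesCol("features_index")
    .setMaxIter(10)
    .setMaxBins((int) maxCategories);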

Answer


A late reply, but I hit the same error while running GBT on a dataset read from a CSV file. Adding .setLabels(labelIndexer.labels()) to the labelConverter fixed the problem for me:

IndexToString labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels());
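
Without explicit labels, IndexToString falls back to reading label metadata from its input column's ML attribute and casting it to NominalAttribute; the GBT prediction column carries no such metadata, which is where the UnresolvedAttribute$ to NominalAttribute ClassCastException comes from.

For context, a minimal sketch (reusing the stage and column names from the question) of how the fixed converter wires back into the pipeline:

StringIndexerModel labelIndexer = new StringIndexer()
        .setInputCol("click")
        .setOutputCol("indexedclick")
        .fit(dataFrame);

// Supplying the labels explicitly avoids the failing metadata lookup.
IndexToString labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels());

Pipeline pipeline = new Pipeline().setStages(
        new PipelineStage[] { labelIndexer, vectorAssembler, gbt, labelConverter });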