Why can't my saved LDA model be loaded correctly in Spark?


I'm learning Spark MLlib and can't load an LDA model that I saved.

I'm using Spark 2.1.0 and Scala 2.11.11.

If I use the model right after training it, in the same class:

import org.apache.spark.ml.clustering.LDA

val lda = new LDA().setK(20).setMaxIter(10)
val model = lda.fit(ldadata)
val topics = model.describeTopics(10)
println("The topics described by their top-weighted terms:")
topics.show(false)

I get exactly the correct output I need.
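For context, ldadata is a DataFrame with a term-count vector column named "features" (the default column LDA reads). A minimal sketch of one way to build such input; the toy corpus, app name, and SparkSession setup below are placeholders, not my actual pipeline:

import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("lda-example").master("local[*]").getOrCreate()
import spark.implicits._

// Toy corpus standing in for the real documents.
val docs = Seq(
  "spark mllib lda topic model".split(" ").toSeq,
  "scala spark machine learning pipeline".split(" ").toSeq
).toDF("words")

// LDA reads a vector column named "features" by default,
// so produce term-count vectors under that name.
val ldadata = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .fit(docs)
  .transform(docs)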

But if I save the model

model.write.overwrite().save("/usr/local/LDAmodel400w")

and then try to load it in another class

import org.apache.spark.ml.clustering.LocalLDAModel

val model = LocalLDAModel.load("/usr/local/LDAmodel400w")
val topics = model.describeTopics(10)
println("The topics described by their top-weighted terms:")
topics.show(false)

I get

Exception in thread "main" java.lang.UnsupportedOperationException: empty collection
        at org.apache.spark.rdd.RDD$$anonfun$first$1.apply(RDD.scala:1369)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
        at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
        at org.apache.spark.rdd.RDD.first(RDD.scala:1366)
        at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:379)
        at org.apache.spark.ml.clustering.DistributedLDAModel$DistributedLDAModelReader.load(LDA.scala:767)
        at org.apache.spark.ml.clustering.DistributedLDAModel$DistributedLDAModelReader.load(LDA.scala:762)
        at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:215)
        at org.apache.spark.ml.clustering.DistributedLDAModel$.load(LDA.scala:781)
        at ldatest.ldamaven.loadlda$.main(loadlda.scala:31)
        at ldatest.ldamaven.loadlda.main(loadlda.scala)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:498)
        at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:738)
        at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:187)
        at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:212)
        at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:126)
        at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)
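
Reading the trace, the failure is inside DefaultParamsReader.loadMetadata, which as far as I can tell reads <path>/metadata as a text RDD and calls first() on it; calling first() on an empty RDD is exactly what throws "empty collection". So one sanity check (my own debugging idea, assuming the same SparkSession named spark) is to read the metadata back directly:

val meta = spark.sparkContext.textFile("/usr/local/LDAmodel400w/metadata")
// Throws the same "empty collection" error if the metadata
// directory is missing or empty on the filesystem Spark reads from.
println(meta.first())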

I'd like to know where I'm going wrong.
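
Edit: one more detail in case it matters. The trace goes through DistributedLDAModel.load (at loadlda.scala:31) even though the snippet above calls LocalLDAModel.load. My understanding of which loader matches which optimizer, in case that mismatch is part of the problem (this pairing is my reading of the docs, not something I've verified):

import org.apache.spark.ml.clustering.{DistributedLDAModel, LocalLDAModel}

// The default optimizer is "online", so fit() returns a LocalLDAModel:
val localModel = LocalLDAModel.load("/usr/local/LDAmodel400w")

// Only a model trained with .setOptimizer("em") is a DistributedLDAModel:
// val distModel = DistributedLDAModel.load("/usr/local/LDAmodel400w")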
