PySpark: create a new column from another column using multiple conditions with a list or set

I am trying to create a new column in a PySpark DataFrame. I have the following data:

+------+
|letter|
+------+
|     A|
|     C|
|     A|
|     Z|
|     E|
+------+

I want to add a new column based on the given column, for example:

+------+-----+
|letter|group|
+------+-----+
|     A|   c1|
|     B|   c1|
|     F|   c2|
|     G|   c2|
|     I|   c3|
+------+-----+

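There can be multiple categories, each covering many individual values of letter (around 100 in total, some of which contain more than one letter).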

I have already done this with a udf, and it works fine:

from pyspark.sql.functions import udf
from pyspark.sql.types import *

c1 = ['A','B','C','D']
c2 = ['E','F','G','H']
c3 = ['I','J','K','L']
...

def l2c(value):
    if value in c1: return 'c1'
    elif value in c2: return 'c2'
    elif value in c3: return 'c3'
    else: return "na"

udf_l2c = udf(l2c, StringType())
data_with_category = data.withColumn("group", udf_l2c("letter"))

Now I am trying to do it without a udf, perhaps using when and col. What I have tried is the following. It works, but the code gets very long:

data_with_category = data.withColumn('group', when(col('letter') == 'A' ,'c1')
    .when(col('letter') == 'B', 'c1')
    .when(col('letter') == 'F', 'c2')
    ... 

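Writing a separate condition for every possible value of letter does not scale. In my case the number of letters can be very large (around 100). So I tried: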

data_with_category = data.withColumn('group', when(col('letter') in ['A','B','C','D'] ,'c1')
    .when(col('letter') in ['E','F','G','H'], 'c2')
    .when(col('letter') in ['I','J','K','L'], 'c3')

But it returns an error. How can I solve this?
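
A likely cause of the error is that Python's `in` operator compares the Column with each list element and then asks for a plain True/False, which a Column cannot provide. One way to express the same condition is `Column.isin`, which accepts a list. The sketch below is only an illustration under assumptions: the `GROUPS` dictionary and the helper names `overlapping_values` and `group_column` are hypothetical and not part of the original code; the group contents mirror `c1`, `c2`, and `c3` above.

```python
# Hypothetical group definitions, mirroring c1/c2/c3 from the question.
GROUPS = {
    "c1": ["A", "B", "C", "D"],
    "c2": ["E", "F", "G", "H"],
    "c3": ["I", "J", "K", "L"],
}


def overlapping_values(groups):
    """Return the sorted values that are listed in more than one group."""
    seen, duplicates = set(), set()
    for values in groups.values():
        for value in set(values):
            if value in seen:
                duplicates.add(value)
            seen.add(value)
    return sorted(duplicates)


def group_column(groups, source="letter", default="na"):
    """Build one Column expression: a when(isin) per group, then otherwise(default)."""
    # Imported lazily so the pure-Python check above runs without Spark installed.
    from pyspark.sql import functions as F

    expr = None
    for name, values in groups.items():
        condition = F.col(source).isin(list(values))
        # The first branch starts the chain with F.when; later ones extend it.
        expr = F.when(condition, name) if expr is None else expr.when(condition, name)
    # Rows that match no group get the default; an empty `groups` yields a constant.
    return F.lit(default) if expr is None else expr.otherwise(default)
```

With the question's DataFrame this would be used as `data_with_category = data.withColumn("group", group_column(GROUPS))`. Branches in a `when` chain are evaluated in order and the first match wins, so running `overlapping_values(GROUPS)` beforehand is a cheap way to spot a value that was accidentally placed in two groups.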

Comments
  • xid

    You could try using a udf, for example:

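    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType

    # The greeting is inlined so the lambda does not rely on an undefined say_hello.
    say_hello_udf = udf(lambda name: f"Hello {name}", StringType())
    df = spark.createDataFrame([("Rick",), ("Morty",)], ["name"])
    df.withColumn("greetings", say_hello_udf(col("name"))).show()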
    

    Or, using the decorator form:

    @udf(returnType=StringType())
    def say_hello(name):
       return f"Hello {name}"
    df.withColumn("greetings", say_hello(col("name"))).show()