Difference between the elements of PySpark lists that contain one or two elements

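I have a PySpark DataFrame with a column of lists that contain one or two elements. When a list has two elements, they are not sorted in ascending or descending order.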

+--------+----------+-------+
| version| timestamp| list  |
+--------+----------+-------+
| v1     |2012-01-10| [5,2] |
| v1     |2012-01-11| [2,5] |
| v1     |2012-01-12| [3,2] |
| v2     |2012-01-12| [2]   |
| v2     |2012-01-11| [1,2] |
| v2     |2012-01-13| [1]   |
+--------+----------+-------+

I want to take the difference between the first and second elements of the list (when there are two elements) and store it in another column (diff). When the list has only one element, I want the output to be zero. Here is an example of the output I want.

+--------+----------+-------+-------+
| version| timestamp| list  |  diff |
+--------+----------+-------+-------+
| v1     |2012-01-10| [5,2] |   3   |
| v1     |2012-01-11| [2,5] |  -3   |
| v1     |2012-01-12| [3,2] |   1   |
| v2     |2012-01-12| [2]   |   0   |
| v2     |2012-01-11| [1,2] |  -1   |
| v2     |2012-01-13| [1]   |   0   |
+--------+----------+-------+-------+
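A plain-Python reference for the rule above can help when checking whatever Spark produces, and it is the same logic a UDF would wrap. This is a minimal sketch: the helper name `list_diff` is hypothetical, the rows are copied from the input table, and mapping empty or null lists to 0 is an assumption the question does not state.

```python
from typing import List, Optional


def list_diff(values: Optional[List[int]]) -> int:
    """Hypothetical helper: first minus second element for a two-element list.

    One-element lists give 0, as the question asks. Empty or null lists
    also give 0, which is an assumption, not part of the question.
    """
    if values is not None and len(values) == 2:
        return values[0] - values[1]
    return 0


# Rows copied from the question's input table: (version, timestamp, list)
rows = [
    ("v1", "2012-01-10", [5, 2]),
    ("v1", "2012-01-11", [2, 5]),
    ("v1", "2012-01-12", [3, 2]),
    ("v2", "2012-01-12", [2]),
    ("v2", "2012-01-11", [1, 2]),
    ("v2", "2012-01-13", [1]),
]

for version, timestamp, values in rows:
    print(version, timestamp, values, list_diff(values))
```

Wrapping it with `pyspark.sql.functions.udf(list_diff, IntegerType())` would turn it into a UDF, at the usual cost of moving rows between the JVM and Python.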

My question is similar to one I asked before, but not exactly the same.

How can I do this in PySpark?

I'm also open to using a UDF to get the expected output, if needed.

Both UDF-free and UDF-based approaches are welcome. Thanks.

Comments
  • 纯情小火鸡 replied:

    Adding to @Shu's answer from earlier, just add a when/otherwise clause that checks the size of the array.

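    from pyspark.sql import functions as F

    # Compute list[0] - list[1] only when the array has exactly two elements; otherwise 0
    df.withColumn("diff", F.when(F.size('list') == 2, F.expr("""transform(array(list), x -> x[0] - x[1])""")[0])
                           .otherwise(F.lit(0))).show()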
    
    #+------+----+
    #|  list|diff|
    #+------+----+
    #|[5, 2]|   3|
    #|[2, 5]|  -3|
    #|   [2]|   0|
    #+------+----+
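The `transform(array(list), x -> x[0] - x[1])[0]` expression wraps the row's list in a one-element outer array, applies the lambda to that single element, and takes element 0 of the result, so it amounts to `list[0] - list[1]`. The `when(size(list) == 2)` guard is what sends one-element lists to 0: in Spark with ANSI mode off, indexing past the end of an array yields null rather than a number (with ANSI mode on, Spark raises an error instead). The plain-Python model below mirrors those steps as an analogy only; the function names are hypothetical and `None` stands in for SQL null.

```python
from typing import List, Optional


def spark_get_item(arr: List[int], i: int) -> Optional[int]:
    """Model of Spark's arr[i] with ANSI mode off: an out-of-range index gives None (null)."""
    return arr[i] if 0 <= i < len(arr) else None


def unguarded_diff(lst: List[int]) -> Optional[int]:
    """Model of transform(array(list), x -> x[0] - x[1])[0] without the size check."""
    wrapped = [lst]  # array(list): a one-element outer array holding the row's list
    mapped = []
    for x in wrapped:  # transform(...): apply the lambda to each outer element
        first, second = spark_get_item(x, 0), spark_get_item(x, 1)
        # SQL arithmetic with a null operand yields null
        mapped.append(None if first is None or second is None else first - second)
    return mapped[0]  # [0]: take the single mapped value back out


def guarded_diff(lst: List[int]) -> int:
    """Model of when(size(list) == 2, <expression above>).otherwise(lit(0))."""
    return unguarded_diff(lst) if len(lst) == 2 else 0


for lst in ([5, 2], [2, 5], [2]):
    print(lst, unguarded_diff(lst), guarded_diff(lst))
```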
    
  • 妮听咱话 replied:

    You could also define a UDF like so.

    Sample data:

    data = [
        ('v1', [5, 2],),
        ('v1', [2, 5],),
        ('v1', [3, 2],),
        ('v2', [2],),
        ('v2', [1, 2],),
        ('v2', [1],),
    ]
    df = spark.createDataFrame(data, ['version', 'list'])
    

    from functools import reduce
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType
    
    # UDF definition: first minus second element for two-element lists, 0 otherwise
    find_diff = udf(lambda a: reduce(lambda x, y: x - y, a) if len(a) == 2 else 0, IntegerType())
    
    (
        df
        .withColumn('diff', find_diff('list'))
        .show(truncate=False)
    )
    
    +-------+------+----+
    |version|list  |diff|
    +-------+------+----+
    |v1     |[5, 2]|3   |
    |v1     |[2, 5]|-3  |
    |v1     |[3, 2]|1   |
    |v2     |[2]   |0   |
    |v2     |[1, 2]|-1  |
    |v2     |[1]   |0   |
    +-------+------+----+
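Python's `functools.reduce` returns the only element unchanged when the iterable has a single item (the lambda is never called), so a bare `reduce(lambda x, y: x - y, a)` gives 2 and 1 for `[2]` and `[1]` rather than the requested 0. That is why the UDF body needs an explicit length check. The check below runs without Spark; `bare` and `checked` are illustrative names, not part of the answer.

```python
from functools import reduce


def bare(a):
    # Subtract left to right with no length check
    return reduce(lambda x, y: x - y, a)


def checked(a):
    # Same subtraction, but only for two-element lists; otherwise 0
    return reduce(lambda x, y: x - y, a) if len(a) == 2 else 0


samples = [[5, 2], [2, 5], [3, 2], [2], [1, 2], [1]]
print([bare(a) for a in samples])     # [3, -3, 1, 2, -1, 1]
print([checked(a) for a in samples])  # [3, -3, 1, 0, -1, 0]
```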