Input should be as below:

company sales
amazon 100
flipkart 900
ebay 890
amazon 100
flipkart 100
ebay 10
amazon 100
flipkart 90
ebay 10

And expected output should be as below:

amazon flipkart ebay
300 1090 910

I tried using the pivot function, but it's not working. Any help on this would be appreciated. Thanks in advance.

3 Answers


  1. I’ve been using the following test dataset to compose this solution:

    data = [("amazon", 100), ("flipkart", 300), ("amazon", 50), ("ebay", 50), ("ebay", 150), ("amazon", 300)]
    columns= ["company", "sales"]
    df = spark.createDataFrame(data = data, schema = columns)
    

    This will result in the following dataframe:

    +--------+-----+
    |company |sales|
    +--------+-----+
    |amazon  |100  |
    |flipkart|300  |
    |amazon  |50   |
    |ebay    |50   |
    |ebay    |150  |
    |amazon  |300  |
    +--------+-----+
    

    For your specific wide table, I would do the following:

    (df
      .groupBy("company")
      .pivot("company")
      .sum("sales"))
    

    Spark will keep the null values. The pivot function is really expensive on its own.
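
    On the test dataset above, this intermediate pivot should look roughly like the following (a sketch; pivot orders the columns by sorting the distinct values):

    +--------+------+----+--------+
    | company|amazon|ebay|flipkart|
    +--------+------+----+--------+
    |  amazon|   450|null|    null|
    |    ebay|  null| 200|    null|
    |flipkart|  null|null|     300|
    +--------+------+----+--------+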

    To actually get your result, you can remove the null values by grouping and summing again:

    from pyspark.sql import functions as F

    # Collect the distinct company names to drive the final aggregation
    companies = list(
        df.select('company').toPandas()['company'].unique()
    )
    # Pivot, then collapse the per-company rows into a single row,
    # summing away the nulls
    pdf = (df
      .groupBy("company")
      .pivot("company")
      .sum("sales")
      .groupBy()
      .agg(*[F.sum(c).alias(c) for c in companies]))
    

    Result is:

    +------+--------+----+
    |amazon|flipkart|ebay|
    +------+--------+----+
    |   450|     300| 200|
    +------+--------+----+
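
    One way to blunt pivot's cost, assuming the companies list built above: pass the distinct values to pivot() explicitly, so Spark skips the extra job it would otherwise run to discover them.

    # Sketch: same aggregation, but with the pivot values supplied up front
    pdf = (df
      .groupBy("company")
      .pivot("company", companies)
      .sum("sales")
      .groupBy()
      .agg(*[F.sum(c).alias(c) for c in companies]))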
    
  2. Use the groupBy() and first() functions.

    data = [("amazon", 100), ("flipkart", 300), ("amazon", 50), ("ebay", 50), ("ebay", 150), ("amazon", 300)]
    df = spark.createDataFrame(data,["company","sales"])
    df.show()
    +--------+-----+
    | company|sales|
    +--------+-----+
    |  amazon|  100|
    |flipkart|  300|
    |  amazon|   50|
    |    ebay|   50|
    |    ebay|  150|
    |  amazon|  300|
    +--------+-----+
    
    df.groupBy("company").agg(sum("sales").alias("sales"))
    .groupBy().pivot("company").agg(first("sales")).show()
    
    
    +------+----+--------+
    |amazon|ebay|flipkart|
    +------+----+--------+
    |   450| 200|     300|
    +------+----+--------+
    

    Thanks to @andy for the helper-column solution:

    (df.groupBy(lit(0).alias("Key")).pivot("company")
      .agg(sum("sales").alias("sales")).show())
    
    +---+------+----+--------+
    |Key|amazon|ebay|flipkart|
    +---+------+----+--------+
    |  0|   450| 200|     300|
    +---+------+----+--------+
    

    The Key column can be removed using the drop() function:

    (df.groupBy(lit(0).alias("Key")).pivot("company")
      .agg(sum("sales").alias("sales")).drop("Key").show())
    +------+----+--------+
    |amazon|ebay|flipkart|
    +------+----+--------+
    |   450| 200|     300|
    +------+----+--------+
    
  3. You can use the pivot() function as shown in the other answers.

    Here is a shorter version of the groupBy() and pivot() combination.

    >>> data = [("amazon", 100), ("flipkart", 300), ("amazon", 50), ("ebay", 50), ("ebay", 150), ("amazon", 300)]
    >>> df = spark.createDataFrame(data,["company","sales"])
    >>> df.groupBy().pivot('company').sum('sales').show()
    +------+----+--------+
    |amazon|ebay|flipkart|
    +------+----+--------+
    |   450| 200|     300|
    +------+----+--------+
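
    If you also need the columns in the asker's order (amazon, flipkart, ebay), a select() after the pivot reorders them. A minimal sketch, reusing the df above:

    >>> df.groupBy().pivot('company').sum('sales').select('amazon', 'flipkart', 'ebay').show()
    +------+--------+----+
    |amazon|flipkart|ebay|
    +------+--------+----+
    |   450|     300| 200|
    +------+--------+----+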
    