
I have recently started working with Spark in Scala. I have a requirement where I need to compute the sums of a few columns inside a case statement. I've written the corresponding Spark SQL code but am unable to implement the same logic dynamically in Spark Scala. Below is what I'm trying to achieve:

SQL Code

Select  col_A,
        round(case when sum(amt_M)   <> 0.0 then sum(amt_M) 
                   when sum(amt_N)   <> 0.0 then sum(amt_N)
                   when sum(amt_P)   <> 0.0 then sum(amt_P) 
              end,1) as pct 
from table_T1
group by col_A

The use case is to read the column names from a variable and build the case-statement logic above dynamically. There are currently 3 columns, but that number could increase later on.
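For reference, one way to make the SQL above dynamic is to generate its text from the comma-separated column list. The sketch below is plain Scala; `PctExprBuilder` and `buildPctExpr` are hypothetical names, and it assumes the column names are simple identifiers that need no quoting.

```scala
object PctExprBuilder {
  // Hypothetical helper: builds the SQL text of the CASE expression shown above
  // for an arbitrary, ordered list of amount columns (first non-zero sum wins).
  def buildPctExpr(cols: Seq[String], scale: Int = 1): String = {
    require(cols.nonEmpty, "at least one column is required")
    val branches = cols
      .map(c => s"when sum($c) <> 0.0 then sum($c)")
      .mkString(" ")
    s"round(case $branches end, $scale)"
  }

  def main(args: Array[String]): Unit = {
    val cols = "amt_M,amt_N,amt_P"
    val aggCols = cols.split(",").map(_.trim).toSeq
    println(buildPctExpr(aggCols))
  }
}
```

In Spark, the resulting string could then be passed to `expr` from `org.apache.spark.sql.functions`, for example `df.groupBy("col_A").agg(expr(buildPctExpr(aggCols)).as("pct"))`. Because the names are interpolated as-is, columns containing spaces or special characters would need backticks.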

Below is the code I tried in Spark Scala:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection._

val df = spark.table("database.table_T1")

val cols = "amt_M,amt_N,amt_P"

val aggCols = cols.split(",").toSeq

val sums = aggCols.map(colName => when(round(sum(colName).cast(DoubleType),1) =!= 0.0,sum(colName).cast(DoubleType).alias("sum_"+colName)))

val df2 = df.groupBy(col("col_A")).agg(sums.head, sums.tail:_*)

However, this is not giving the desired result. Please help me with this.

Input Data

+--------+--------------------+---------------------+----------------------+
|col_A   |amt_M               |amt_N                |amt_P                 |
+--------+--------------------+---------------------+----------------------+
|5C-SVS-1|0.0                 |0.04064912622009295  |1.6256888829356116E-4 |
|5C-SVS-1|0.0                 |0.026542159153759487 |8.574900251977566E-4  |
|5C-SVS-1|0.0                 |5.703894148377958E-5 |1.0745888408402782E-7 |
|5C-SVS-1|0.0                 |0.0                  |4.514561031069833E-4  |
|5C-SVS-1|0.0                 |0.011794053124022862 |0.0020388259536434656 |
|5C-SVS-1|0.0                 |7.55793849084569E-4  |0.0017105736019335327 |
|5C-SVS-1|0.0                 |0.019303776946698548 |2.240625765755109E-5  |
|5C-SVS-1|0.0                 |-8.028117213883126E-6|-2.1979360825171534E-6|
|5C-SVS-1|0.001940948839163001|0.029163686986129422 |0.09505621692309557   |
|5C-SVS-1|0.0                 |2.515835289984397E-7 |1.1486227577926157E-8 |
|5C-SVS-1|0.0                 |0.007844299114837874 |9.974187712854785E-4  |
|5C-SVS-1|0.0                 |5.033123682586349E-4 |1.3644443189731007E-4 |
|5C-SVS-1|0.0                 |0.026331681277001386 |6.022434166108063E-4  |
|5C-SVS-1|0.0                 |8.098023638080503E-6 |1.0                   |
|5C-SVS-1|0.0                 |0.03655893437209876  |0.003113370686486882  |
|5C-SVS-1|0.0                 |0.01409363925733864  |6.239415097038338E-4  |
|5C-SVS-1|0.0                 |0.02171856350557304  |0.0                   |
|5C-SVS-1|0.008435341548288601|0.03347191686227869  |0.35221710556006247   |
|5C-SVS-1|0.0                 |-2.547132732700875E-6|-0.13073525789233997  |
|5C-SVS-1|0.006057441518729214|0.024036273783621134 |0.21447606070652467   |
+--------+--------------------+---------------------+----------------------+

Expected Output

+--------+---+
|   col_A|pct|
+--------+---+
|5C-SVS-1|1.0|
+--------+---+

Thanks


Answers


  1. Chosen as BEST ANSWER

    I solved the requirement by implementing the following method:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._
    
    def getSumCols(columnList: List[String]): Column = {
      // Start the CASE expression with the first column:
      // WHEN sum(first) <> 0.0 THEN sum(first)
      var conditionColumn: Column = when(sum(col(columnList(0)).cast(DoubleType)) =!= 0.0, sum(col(columnList(0)).cast(DoubleType)))
    
      // Append one WHEN branch for each remaining column (2nd element to the end)
      for (c <- 1 until columnList.length) {
        conditionColumn = conditionColumn.when(sum(col(columnList(c)).cast(DoubleType)) =!= 0.0, sum(col(columnList(c)).cast(DoubleType)))
      }
    
      // Round the selected sum to one decimal place, as in the SQL
      round(conditionColumn, 1)
    }
    

    The method is then called during the aggregation as follows:

    val cols = "amt_M,amt_N,amt_P"
    
    val colList = cols.split(",").toList
    
    val conditionColumn: Column = getSumCols(colList)
    
    val df1 = df.groupBy("col_A").agg(conditionColumn.alias("pct"))
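
    For comparison, here is a sketch of the same chained `CASE WHEN` built with `foldLeft` instead of a mutable `var`. `PctColumn`, `pctColumn` and `dsum` are illustrative names, not part of the answer above; the intent is to match `getSumCols`, including the rounding to one decimal place. It assumes Spark SQL is on the classpath.

    ```scala
    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{col, round, sum, when}
    import org.apache.spark.sql.types.DoubleType

    object PctColumn {
      // Sum of one column cast to double, as in getSumCols
      private def dsum(c: String): Column = sum(col(c).cast(DoubleType))

      // First branch from the head, then one extra WHEN per remaining column
      def pctColumn(columnList: Seq[String]): Column = {
        require(columnList.nonEmpty, "at least one column is required")
        val first = when(dsum(columnList.head) =!= 0.0, dsum(columnList.head))
        val chained = columnList.tail.foldLeft(first) { (acc, c) =>
          acc.when(dsum(c) =!= 0.0, dsum(c))
        }
        // No otherwise(...): if every sum is 0.0 the result is null, like the SQL CASE
        round(chained, 1)
      }
    }
    ```

    Usage would mirror the call above, e.g. `df.groupBy("col_A").agg(PctColumn.pctColumn(colList).alias("pct"))`.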
    

  2. You could first groupBy your Dataframe on col_A, calculate the sums, and afterwards use a map operation to choose which sum to keep. Something like this:

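    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.Row
    import spark.implicits._ // needed for .map on the Dataframe and .toDF (auto-imported in spark-shell)
    
    // Define the schema explicitly so that each CSV column is read with the intended type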
    val schema = new StructType()
        .add("col_A", StringType)
        .add("amt_M", DoubleType)
        .add("amt_N", DoubleType)
        .add("amt_P", DoubleType)
    
    // Reading in the Dataframe using our premade schema. I put the data in a CSV
    // file with ; as delimiters.
    val df = spark.read
        .option("header", "true")
        .option("sep",";")
        .schema(schema)
        .csv("./someData.csv")
    
    df.show
    +--------+--------------------+--------------------+--------------------+
    |   col_A|               amt_M|               amt_N|               amt_P|
    +--------+--------------------+--------------------+--------------------+
    |5C-SVS-1|                 0.0| 0.04064912622009295|1.625688882935611...|
    |5C-SVS-1|                 0.0|0.026542159153759487|8.574900251977566E-4|
    |5C-SVS-1|                 0.0|5.703894148377958E-5|1.074588840840278...|
    |5C-SVS-1|                 0.0|                 0.0|4.514561031069833E-4|
    |5C-SVS-1|                 0.0|0.011794053124022862|0.002038825953643...|
    |5C-SVS-1|                 0.0| 7.55793849084569E-4|0.001710573601933...|
    |5C-SVS-1|                 0.0|0.019303776946698548|2.240625765755109E-5|
    |5C-SVS-1|                 0.0|-8.02811721388312...|-2.19793608251715...|
    |5C-SVS-1|0.001940948839163001|0.029163686986129422| 0.09505621692309557|
    |5C-SVS-1|                 0.0|2.515835289984397E-7|1.148622757792615...|
    |5C-SVS-1|                 0.0|0.007844299114837874|9.974187712854785E-4|
    |5C-SVS-1|                 0.0|5.033123682586349E-4|1.364444318973100...|
    |5C-SVS-1|                 0.0|0.026331681277001386|6.022434166108063E-4|
    |5C-SVS-1|                 0.0|8.098023638080503E-6|                 1.0|
    |5C-SVS-1|                 0.0| 0.03655893437209876|0.003113370686486882|
    |5C-SVS-1|                 0.0| 0.01409363925733864|6.239415097038338E-4|
    |5C-SVS-1|                 0.0| 0.02171856350557304|                 0.0|
    |5C-SVS-1|0.008435341548288601| 0.03347191686227869| 0.35221710556006247|
    |5C-SVS-1|                 0.0|-2.54713273270087...|-0.13073525789233997|
    |5C-SVS-1|0.006057441518729214|0.024036273783621134| 0.21447606070652467|
    +--------+--------------------+--------------------+--------------------+
    
    // Aggregating our data for each distinct value in col_A, summing and rounding each amt column.
    // The alias goes on the rounded expression so the output columns get clean names.
    val aggregated_df = df.groupBy(col("col_A"))
        .agg(
            round(sum(col("amt_M")), 1).as("amt_M_sum"),
            round(sum(col("amt_N")), 1).as("amt_N_sum"),
            round(sum(col("amt_P")), 1).as("amt_P_sum")
    )
    
    aggregated_df.show
    +--------+---------+---------+---------+
    |   col_A|amt_M_sum|amt_N_sum|amt_P_sum|
    +--------+---------+---------+---------+
    |5C-SVS-1|      0.0|      0.3|      1.5|
    +--------+---------+---------+---------+
    
    
    // Selecting our wanted values. We make use of Scala pattern matching here to
    // easily deconstruct our data and make something readable
    val output = aggregated_df.map(
        row => row match {
            case Row(col_A: String, sum_amt_M: Double, sum_amt_N: Double, sum_amt_P: Double) => {
                if (sum_amt_M != 0.0)
                    (col_A, sum_amt_M)
                else if (sum_amt_N != 0.0)
                    (col_A, sum_amt_N)
                else
                    (col_A, sum_amt_P)
            }
        }
    ).toDF("col_A", "pct")
    
    output.show
    +--------+---+
    |   col_A|pct|
    +--------+---+
    |5C-SVS-1|0.3|
    +--------+---+
    

    Note: what should happen if all of the sums are 0? That is up to you to decide. I used sum_amt_P as the catch-all else case (the SQL CASE in the question has no ELSE, so it would return null in that situation), but you can adapt the logic inside the map function to return whatever you need.
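
    One detail worth noting: the aggregation above rounds each sum before the map compares it with 0.0, whereas the SQL in the question compares the raw sums and rounds only the chosen one. The two orders can pick different columns when a sum is non-zero but rounds to 0.0. The plain-Scala sketch below (no Spark; `FirstNonZero` and its methods are hypothetical names, and the sums are illustrative values, not computed from the sample data) models both orders and shows how the hard-coded if/else chain generalizes to any number of columns with `find`.

    ```scala
    import scala.math.BigDecimal.RoundingMode

    object FirstNonZero {
      // Round half-up to one decimal place, like Spark's round(col, 1)
      def round1(x: Double): Double =
        BigDecimal(x).setScale(1, RoundingMode.HALF_UP).toDouble

      // Like the question's SQL: compare raw sums, round only the chosen one.
      // None stands in for the NULL a CASE without ELSE returns when every sum is 0.
      def pickThenRound(sums: Seq[Double]): Option[Double] =
        sums.find(_ != 0.0).map(round1)

      // Like the map above: the sums are already rounded when compared with 0.0
      def roundThenPick(sums: Seq[Double]): Option[Double] =
        sums.map(round1).find(_ != 0.0)

      def main(args: Array[String]): Unit = {
        // Illustrative sums for (amt_M, amt_N, amt_P)
        val sums = Seq(0.016, 0.29, 1.54)
        println(pickThenRound(sums)) // Some(0.0)
        println(roundThenPick(sums)) // Some(0.3)
      }
    }
    ```

    Inside the Spark map, the sums for a row could be gathered with something like `(1 until row.length).map(row.getDouble)` before applying either function, assuming every column after col_A is a non-null Double.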

    Hope this helps!
