I have recently started working with Spark/Scala. I have a requirement where I need to find the sum of a few columns within a case statement. I have written the corresponding Spark SQL code, but I am unable to implement the same logic dynamically in Spark/Scala. Below is what I am trying to achieve.

SQL code:
Select col_A,
       round(case when sum(amt_M) <> 0.0 then sum(amt_M)
                  when sum(amt_N) <> 0.0 then sum(amt_N)
                  when sum(amt_P) <> 0.0 then sum(amt_P)
             end, 1) as pct
from table_T1
group by col_A
The use case is to pick the column names from a variable so that the case-statement logic above can be implemented dynamically. Currently there are 3 columns, but that number could increase later on.

Below is the code I tried in Spark/Scala:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection._

val df = spark.table("database.table_T1")
// Column names arrive as a comma-separated string.
val cols = "amt_M,amt_N,amt_P"
val aggCols = cols.split(",").toSeq
// One conditional aggregate per column; the alias goes on the whole when(...)
// expression so the result columns get readable names.
val sums = aggCols.map { colName =>
  when(round(sum(colName).cast(DoubleType), 1) =!= 0.0,
       sum(colName).cast(DoubleType)).alias("sum_" + colName)
}
val df2 = df.groupBy(col("col_A")).agg(sums.head, sums.tail: _*)
However, this is not giving the desired results. Please help me with this.
Input Data
+--------+--------------------+---------------------+----------------------+
|col_A |amt_M |amt_N |amt_P |
+--------+--------------------+---------------------+----------------------+
|5C-SVS-1|0.0 |0.04064912622009295 |1.6256888829356116E-4 |
|5C-SVS-1|0.0 |0.026542159153759487 |8.574900251977566E-4 |
|5C-SVS-1|0.0 |5.703894148377958E-5 |1.0745888408402782E-7 |
|5C-SVS-1|0.0 |0.0 |4.514561031069833E-4 |
|5C-SVS-1|0.0 |0.011794053124022862 |0.0020388259536434656 |
|5C-SVS-1|0.0 |7.55793849084569E-4 |0.0017105736019335327 |
|5C-SVS-1|0.0 |0.019303776946698548 |2.240625765755109E-5 |
|5C-SVS-1|0.0 |-8.028117213883126E-6|-2.1979360825171534E-6|
|5C-SVS-1|0.001940948839163001|0.029163686986129422 |0.09505621692309557 |
|5C-SVS-1|0.0 |2.515835289984397E-7 |1.1486227577926157E-8 |
|5C-SVS-1|0.0 |0.007844299114837874 |9.974187712854785E-4 |
|5C-SVS-1|0.0 |5.033123682586349E-4 |1.3644443189731007E-4 |
|5C-SVS-1|0.0 |0.026331681277001386 |6.022434166108063E-4 |
|5C-SVS-1|0.0 |8.098023638080503E-6 |1.0 |
|5C-SVS-1|0.0 |0.03655893437209876 |0.003113370686486882 |
|5C-SVS-1|0.0 |0.01409363925733864 |6.239415097038338E-4 |
|5C-SVS-1|0.0 |0.02171856350557304 |0.0 |
|5C-SVS-1|0.008435341548288601|0.03347191686227869 |0.35221710556006247 |
|5C-SVS-1|0.0 |-2.547132732700875E-6|-0.13073525789233997 |
|5C-SVS-1|0.006057441518729214|0.024036273783621134 |0.21447606070652467 |
+--------+--------------------+---------------------+----------------------+
Expected Output
+--------+---+
| col_A|pct|
+--------+---+
|5C-SVS-1|1.0|
+--------+---+
Thanks
2 Answers
I solved the requirement by implementing the below method.
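A minimal sketch of one way to write such a method, assuming a fold over the column list that builds the same when-chain as the SQL CASE (the name buildCaseExpr is illustrative, not necessarily the poster's exact code):

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

// Builds: when sum(c1) <> 0.0 then sum(c1) when sum(c2) <> 0.0 then sum(c2) ...
// dynamically from a list of column names.
def buildCaseExpr(aggCols: Seq[String]): Column =
  aggCols.tail.foldLeft(when(sum(aggCols.head) =!= 0.0, sum(aggCols.head))) {
    (caseExpr, c) => caseExpr.when(sum(c) =!= 0.0, sum(c))
  }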
This is then called during the aggregation as below.
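Again as a sketch, reusing aggCols from the question, the call site could look like this:

// Apply the dynamically built CASE inside the aggregation, rounded to 1 decimal.
val df2 = df
  .groupBy(col("col_A"))
  .agg(round(buildCaseExpr(aggCols), 1).alias("pct"))

df2.show(false)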
You could first groupBy your DataFrame on col_A, calculate the sums, and afterwards use a map operation to choose which sum you want to take with you. Something like this:
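A minimal sketch of that approach, assuming the column names from the question's data (the selection logic inside map is illustrative):

import org.apache.spark.sql.functions._

// Aggregate once per group, then pick the first non-zero sum per row.
val sums = df
  .groupBy("col_A")
  .agg(
    sum("amt_M").as("sum_amt_M"),
    sum("amt_N").as("sum_amt_N"),
    sum("amt_P").as("sum_amt_P")
  )

import spark.implicits._
val result = sums.map { row =>
  val m = row.getDouble(row.fieldIndex("sum_amt_M"))
  val n = row.getDouble(row.fieldIndex("sum_amt_N"))
  // sum_amt_P serves as the else / catch-all case.
  val chosen =
    if (m != 0.0) m
    else if (n != 0.0) n
    else row.getDouble(row.fieldIndex("sum_amt_P"))
  (row.getString(row.fieldIndex("col_A")), math.round(chosen * 10) / 10.0)
}.toDF("col_A", "pct")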
Note: What do you do if all of the sums == 0? That's up to you to decide: I put the value of sum_amt_P as the else catch-all case. But from here on you can just adapt the logic inside of the map function to get whatever you want.

Hope this helps!