1 | 1 | """Utils functions for Hadoop. |
2 | 2 | """ |
3 | 3 | from __future__ import annotations |
4 | | -from typing import Union |
| 4 | +from typing import Optional, Union |
5 | 5 | import sys |
6 | 6 | import datetime |
7 | 7 | from pyspark.sql import DataFrame, Window |
8 | | -from pyspark.sql.functions import col, spark_partition_id, rank, coalesce, lit, max, sum |
| 8 | +import pyspark.sql.functions as sf |
9 | 9 |
10 | 10 |
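Presumably one motivation for the aliased import: the old from-import shadowed the Python builtins max and sum at module level, and the new repart_hdfs below relies on the builtin max(num_parts, min_num_parts). A minimal sketch of the distinction, illustrative only and assuming nothing beyond a local SparkSession:

# Illustrative sketch, not part of the diff: with the aliased import, the Python
# builtins max/sum stay available while the Spark aggregates live under sf.
from pyspark.sql import SparkSession
import pyspark.sql.functions as sf

spark = SparkSession.builder.master("local[1]").appName("alias-demo").getOrCreate()

num_parts = max(4, 1)  # builtin max on plain ints -> 4

frame = spark.createDataFrame([(1,), (2,), (3,)], ["x"])
frame.agg(sf.max("x").alias("max_x"), sf.sum("x").alias("sum_x")).show()  # Spark aggregates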
11 | 11 | def sample( |
@@ -50,55 +50,81 @@ def calc_global_rank(frame: DataFrame, order_by: Union[str, list[str]]) -> DataF |
50 | 50 | # calculate local rank |
51 | 51 | wspec1 = Window.partitionBy("part_id").orderBy(*order_by) |
52 | 52 | frame_local_rank = frame.orderBy(order_by).withColumn( |
53 | | - "part_id", spark_partition_id() |
| 53 | + "part_id", sf.spark_partition_id() |
54 | 54 | ).withColumn("local_rank", |
55 | | - rank().over(wspec1)).persist() |
| 55 | + sf.rank().over(wspec1)).persist() |
56 | 56 | # calculate accumulative rank |
57 | 57 | wspec2 = Window.orderBy("part_id").rowsBetween( |
58 | 58 | Window.unboundedPreceding, Window.currentRow |
59 | 59 | ) |
60 | 60 | stat = frame_local_rank.groupBy("part_id").agg( |
61 | | - max("local_rank").alias("max_rank") |
| 61 | + sf.max("local_rank").alias("max_rank") |
62 | 62 | ).withColumn("cum_rank", |
63 | | - sum("max_rank").over(wspec2)) |
| 63 | + sf.sum("max_rank").over(wspec2)) |
64 | 64 | # self join and shift 1 row to get sum factor |
65 | 65 | stat2 = stat.alias("l").join( |
66 | 66 | stat.alias("r"), |
67 | | - col("l.part_id") == col("r.part_id") + 1, "left_outer" |
68 | | - ).select(col("l.part_id"), |
69 | | - coalesce(col("r.cum_rank"), lit(0)).alias("sum_factor")) |
| 67 | + sf.col("l.part_id") == sf.col("r.part_id") + 1, "left_outer" |
| 68 | + ).select( |
| 69 | + sf.col("l.part_id"), |
| 70 | + sf.coalesce(sf.col("r.cum_rank"), sf.lit(0)).alias("sum_factor") |
| 71 | + ) |
70 | 72 | return frame_local_rank.join( |
71 | 73 | #broadcast(stat2), |
72 | 74 | stat2, |
73 | 75 | ["part_id"], |
74 | 76 | ).withColumn("rank", |
75 | | - col("local_rank") + col("sum_factor")) |
| 77 | + sf.col("local_rank") + sf.col("sum_factor")) |
76 | 78 |
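For context, a usage sketch of calc_global_rank; the module name hadoop_utils and the sample data are placeholders, not taken from the diff:

from pyspark.sql import SparkSession
from hadoop_utils import calc_global_rank  # hypothetical module name

spark = SparkSession.builder.master("local[2]").appName("rank-demo").getOrCreate()
frame = spark.createDataFrame(
    [("a", 30), ("b", 10), ("c", 20), ("d", 10)], ["key", "score"]
)

# Each row's global rank = its rank within its partition ("local_rank")
# plus the cumulative max rank of all earlier partitions ("sum_factor").
ranked = calc_global_rank(frame, order_by=["score"])
ranked.select("key", "score", "part_id", "local_rank", "rank").show()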
77 | 79 |
78 | | -def repart_hdfs(spark, path: str, num_parts: int, coalesce: bool = False) -> None: |
| 80 | +def repart_hdfs( |
| 81 | + spark, |
| 82 | + src_path: str, |
| 83 | + dst_path: str = "", |
| 84 | + num_parts: Optional[int] = None, |
| 85 | + mb_per_part: float = 64, |
| 86 | + min_num_parts: int = 1, |
| 87 | + coalesce: bool = False |
| 88 | +) -> None: |
79 | 89 | """Repartition a HDFS path of the Parquet format. |
80 | 90 |
81 | 91 | :param spark: A SparkSession object. |
82 | | - :param path: The HDFS path to repartition.
83 | | - :param num_parts: The new number of partitions.
| 92 | + :param src_path: The source HDFS path (Parquet) to repartition; the result goes to dst_path if given, otherwise src_path is replaced in place.
| 93 | + :param num_parts: The new number of partitions; if None, it is derived from the data size so that each partition holds roughly mb_per_part MB, with at least min_num_parts partitions.
84 | 94 | :param coalesce: If True, use coalesce instead of repartition. |
85 | 95 | """ |
86 | | - path = path.rstrip("/") |
87 | | - ts = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f") |
88 | | - path_tmp = path + f"_repart_tmp_{ts}" |
| 96 | + sc = spark.sparkContext |
| 97 | + hdfs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(sc._jsc.hadoopConfiguration()) # pylint: disable=W0212 |
| 98 | + src_path = src_path.rstrip("/") |
| 99 | + src_path_hdfs = sc._jvm.org.apache.hadoop.fs.Path(src_path) # pylint: disable=W0212 |
| 100 | + # num of partitions |
| 101 | + if num_parts is None: |
| 102 | + bytes_path = hdfs.getContentSummary(src_path_hdfs).getLength() |
| 103 | + num_parts = round(bytes_path / 1_048_576 / mb_per_part) |
| 104 | + num_parts = max(num_parts, min_num_parts) |
| 105 | + # temp path for repartitioned table |
| 106 | + if dst_path == src_path: |
| 107 | + dst_path = "" |
| 108 | + if dst_path: |
| 109 | + path_tmp = dst_path |
| 110 | + else: |
| 111 | + ts = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f") |
| 112 | + path_tmp = src_path + f"_repart_tmp_{ts}" |
| 113 | + # repartition |
89 | 114 | if coalesce: |
90 | | - spark.read.parquet(path).coalesce(num_parts) \ |
| 115 | + spark.read.parquet(src_path).coalesce(num_parts) \ |
91 | 116 | .write.mode("overwrite").parquet(path_tmp) |
92 | 117 | else: |
93 | | - spark.read.parquet(path).repartition(num_parts) \ |
| 118 | + spark.read.parquet(src_path).repartition(num_parts) \ |
94 | 119 | .write.mode("overwrite").parquet(path_tmp) |
95 | | - sc = spark.sparkContext |
96 | | - fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(sc._jsc.hadoopConfiguration()) # pylint: disable=W0212 |
97 | | - if fs.delete(sc._jvm.org.apache.hadoop.fs.Path(path), True): # pylint: disable=W0212 |
98 | | - if not fs.rename( |
| 120 | + # path_tmp --> src_path |
| 121 | + if dst_path: |
| 122 | + return |
| 123 | + if hdfs.delete(src_path_hdfs, True): |
| 124 | + if not hdfs.rename( |
99 | 125 | sc._jvm.org.apache.hadoop.fs.Path(path_tmp), # pylint: disable=W0212 |
100 | | - sc._jvm.org.apache.hadoop.fs.Path(path), # pylint: disable=W0212 |
| 126 | + src_path_hdfs,
101 | 127 | ): |
102 | | - sys.exit(f"Failed to rename the HDFS path {path_tmp} to {path}!") |
| 128 | + sys.exit(f"Failed to rename the HDFS path {path_tmp} to {src_path}!") |
103 | 129 | else: |
104 | | - sys.exit(f"Failed to remove the (old) HDFS path: {path}!") |
| 130 | + sys.exit(f"Failed to remove the (old) HDFS path: {src_path}!") |
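A usage sketch of the new repart_hdfs signature; the module name and HDFS paths are placeholders, and the calls assume the source directory already holds Parquet data:

from pyspark.sql import SparkSession
from hadoop_utils import repart_hdfs  # hypothetical module name

spark = SparkSession.builder.appName("repart-demo").getOrCreate()

# In-place compaction: the partition count is derived from the directory size,
# targeting roughly 64 MB per output file (and at least one file).
repart_hdfs(spark, "/user/example/events")

# Explicit partition count, written to a separate destination path; the source
# path is left untouched when dst_path is given.
repart_hdfs(
    spark,
    src_path="/user/example/events",
    dst_path="/user/example/events_16p",
    num_parts=16,
    coalesce=True,  # coalesce avoids a full shuffle when only reducing partitions
)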