|
3 | 3 | import re
|
4 | 4 | import shutil
|
5 | 5 |
|
6 |
| -try: |
7 |
| - import sh |
8 |
| -except ImportError as e: |
9 |
| - raise ImportError( |
10 |
| - "Please install sh in your Python environment.\n" |
11 |
| - " - Pip: pip install sh\n" |
12 |
| - " - Conda: conda install -c conda-forge sh" |
13 |
| - ) from e |
14 |
| - |
15 | 6 |
|
16 | 7 | def main(args):
|
| 8 | + if args.scala_version == "2.12": |
| 9 | + scala_ver = "2.12" |
| 10 | + scala_patchver = "2.12.18" |
| 11 | + elif args.scala_version == "2.13": |
| 12 | + scala_ver = "2.13" |
| 13 | + scala_patchver = "2.13.11" |
| 14 | + else: |
| 15 | + raise ValueError(f"Unsupported Scala version: {args.scala_version}") |
| 16 | + |
17 | 17 | # Clean artifacts
|
18 |
| - for target in pathlib.Path("jvm-packages/").glob("**/target"): |
19 |
| - if target.is_dir(): |
20 |
| - print(f"Removing {target}...") |
21 |
| - shutil.rmtree(target) |
| 18 | + if args.purge_artifacts: |
| 19 | + for target in pathlib.Path("jvm-packages/").glob("**/target"): |
| 20 | + if target.is_dir(): |
| 21 | + print(f"Removing {target}...") |
| 22 | + shutil.rmtree(target) |
22 | 23 |
|
23 | 24 | # Update pom.xml
|
24 | 25 | for pom in pathlib.Path("jvm-packages/").glob("**/pom.xml"):
|
25 | 26 | print(f"Updating {pom}...")
|
26 |
| - sh.sed( |
27 |
| - [ |
28 |
| - "-i", |
29 |
| - f"s/<artifactId>xgboost-jvm_[0-9\\.]*/<artifactId>xgboost-jvm_{args.scala_version}/g", |
30 |
| - str(pom), |
31 |
| - ] |
32 |
| - ) |
| 27 | + with open(pom, "r", encoding="utf-8") as f: |
| 28 | + lines = f.readlines() |
| 29 | + with open(pom, "w", encoding="utf-8") as f: |
| 30 | + replaced_scalaver = False |
| 31 | + replaced_scala_binver = False |
| 32 | + for line in lines: |
| 33 | + for artifact in [ |
| 34 | + "xgboost-jvm", |
| 35 | + "xgboost4j", |
| 36 | + "xgboost4j-gpu", |
| 37 | + "xgboost4j-spark", |
| 38 | + "xgboost4j-spark-gpu", |
| 39 | + "xgboost4j-flink", |
| 40 | + "xgboost4j-example", |
| 41 | + ]: |
| 42 | + line = re.sub( |
| 43 | + f"<artifactId>{artifact}_[0-9\\.]*", |
| 44 | + f"<artifactId>{artifact}_{scala_ver}", |
| 45 | + line, |
| 46 | + ) |
| 47 | + # Only replace the first occurrence of scala.version |
| 48 | + if not replaced_scalaver: |
| 49 | + line, nsubs = re.subn( |
| 50 | + r"<scala.version>[0-9\.]*", |
| 51 | + f"<scala.version>{scala_patchver}", |
| 52 | + line, |
| 53 | + ) |
| 54 | + if nsubs > 0: |
| 55 | + replaced_scalaver = True |
| 56 | + # Only replace the first occurrence of scala.binary.version |
| 57 | + if not replaced_scala_binver: |
| 58 | + line, nsubs = re.subn( |
| 59 | + r"<scala.binary.version>[0-9\.]*", |
| 60 | + f"<scala.binary.version>{scala_ver}", |
| 61 | + line, |
| 62 | + ) |
| 63 | + if nsubs > 0: |
| 64 | + replaced_scala_binver = True |
| 65 | + f.write(line) |
33 | 66 |
|
34 | 67 |
|
35 | 68 | if __name__ == "__main__":
|
36 | 69 | parser = argparse.ArgumentParser()
|
| 70 | + parser.add_argument("--purge-artifacts", action="store_true") |
37 | 71 | parser.add_argument(
|
38 | 72 | "--scala-version",
|
39 | 73 | type=str,
|
|
0 commit comments