-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathpkgmgr.py
More file actions
133 lines (117 loc) · 6.94 KB
/
pkgmgr.py
File metadata and controls
133 lines (117 loc) · 6.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# This is a replica of https://raw.githubusercontent.com/fani-lab/OpeNTF/refs/heads/main/src/pkgmgr.py
import subprocess, sys, importlib, random, numpy, logging, re, os
log = logging.getLogger(__name__)
from omegaconf import OmegaConf
from itertools import chain
from importlib.metadata import version
def is_version_equal(inst_ver: str, req_ver: str) -> bool:
    """
    Return True if the installed version `inst_ver` matches the pinned version `req_ver`.
    A trailing '.*' in req_ver is a prefix wildcard as in PEP 440: '2.1.*' matches '2.1', '2.1.0' and '2.1.3',
    but not '2.10.0'. When req_ver has no local label, the installed local label (e.g. '+cu118') is ignored,
    which is also how pip decides whether a requirement is already satisfied.
    """
    if '+' not in req_ver: inst_ver = inst_ver.split('+', 1)[0]
    if req_ver.endswith('.*'):
        prefix = req_ver[:-2]
        # compare on a '.' boundary so '2.1.*' does not accept '2.10.0'
        return inst_ver == prefix or inst_ver.startswith(prefix + '.')
    return inst_ver == req_ver
def install_pkg(pkg_name):
    """
    Install `pkg_name` using its requirement line from pkg_req_dict.
    A line of the form 'name@https://raw.githubusercontent.com/.../file.py' is downloaded and imported directly
    via wget_import. Any other line is split on whitespace and passed to `pip install` of the running interpreter.
    Raises KeyError if pkg_name has no entry in pkg_req_dict, and ImportError if pip exits with a non-zero status.
    """
    log.info(f'Installing {pkg_name}...')
    req_line = pkg_req_dict[pkg_name][0]
    if '@https://raw.githubusercontent.com/' in req_line and req_line.endswith('.py'):
        # split on the first '@' only, so an '@' inside the URL itself is kept
        wget_import(pkg_name, req_line.split('@', 1)[1])
        return
    # -m makes pip work as a module inside this env, not the system pip!
    process = subprocess.run([sys.executable, '-m', 'pip', 'install'] + req_line.split(), text=True, capture_output=True)
    log.info(process.stdout)
    if process.returncode != 0: raise ImportError(f'Failed to install package: {pkg_name}\n{process.stderr}')
def reinstall_pkg(pkg_name):
    """
    Remove `pkg_name` with `pip uninstall -y`, then install it again through install_pkg.
    Raises ImportError if the uninstall step fails (install_pkg may raise on its own).
    """
    log.info(f'Uninstalling {pkg_name}...')
    # running pip via `python -m` targets this interpreter's environment rather than whatever pip is on PATH
    uninstall_cmd = [sys.executable, '-m', 'pip', 'uninstall', '-y', pkg_name]
    result = subprocess.run(uninstall_cmd, text=True, capture_output=True)
    log.info(result.stdout)
    if result.returncode:
        raise ImportError(f'Failed to uninstall package: {pkg_name}\n{result.stderr}')
    install_pkg(pkg_name)
def _is_local_module(module):
    """Return True if `module` was loaded from a file under the current working directory (e.g. the project's own sources)."""
    module_file = getattr(module, '__file__', None)  # absent/None for built-in and namespace modules
    if not module_file: return False
    cwd = os.path.join(os.path.realpath(os.getcwd()), '')  # trailing separator: '/repo/src' must not match '/repo/src2/...'
    return os.path.realpath(module_file).startswith(cwd)

def install_import(pkg_name, import_path=None, from_module=None):
    """
    Import a module, installing or re-installing its package first when needed.
    pkg_name: the distribution to install and look up in pkg_req_dict (e.g., "beautifulsoup4")
    import_path: module path to import (e.g., "bs4"); defaults to pkg_name
    from_module: if set, return only that attribute of the module (e.g., the BeautifulSoup class)
    Returns the module, or the requested attribute of it.
    Raises ImportError if installation fails or the module still cannot be imported afterwards.
    > hydra = install_import('hydra-core', 'hydra')
    > BeautifulSoup = install_import('beautifulsoup4', 'bs4', 'BeautifulSoup')  # like: from bs4 import BeautifulSoup
    > soup = BeautifulSoup('<html><body><p>Hello</p></body></html>', 'html.parser')
    > print(soup.p.text) # -> "Hello"
    """
    from importlib.metadata import PackageNotFoundError
    import_path = import_path or pkg_name
    try:
        module = importlib.import_module(import_path)
    except ImportError:
        module = None
    # bypass modules imported from inside the project, like evl.metric as in the Adila submodule
    if module is not None and not _is_local_module(module):
        required = pkg_req_dict.get(pkg_name, (None, '0.0.0'))[1]  # packages not listed in the requirements are not version-checked
        if required != '0.0.0':
            # NOTE(review): only equality is checked, so e.g. a '>=' requirement whose installed version differs
            # from the listed number also triggers a reinstall — confirm this is intended.
            try:
                installed = version(pkg_name)
            except PackageNotFoundError:
                module = None  # importable, but no installed distribution named pkg_name: install it below
            else:
                if not is_version_equal(installed, required):
                    log.info(f'{textcolor["yellow"]}Version mismatch detected. {pkg_name} version {installed} is installed, but {required} is required.{textcolor["reset"]}')
                    reinstall_pkg(pkg_name)
                    importlib.invalidate_caches()
                    # import_module returns the copy already cached in sys.modules, so the reinstalled
                    # version only takes effect in a fresh interpreter
                    module = importlib.import_module(import_path)
    if module is None:
        log.info(f'{import_path} not found.')
        install_pkg(pkg_name)
        importlib.invalidate_caches()  # let the import system notice the files pip just added
        module = importlib.import_module(import_path)
    return getattr(module, from_module) if from_module else module
# Not cached across processes: with multiprocessing.Pool(), every worker that cannot find the package downloads it again.
def wget_import(import_path, url):# url = "https://raw.githubusercontent.com/username/other-repo/main/mymodule.py"
    """
    Download a single-file Python module from `url` and import it under the name `import_path`.
    The module is registered in sys.modules[import_path] and returned. The source is written to a temporary
    .py file, which is deleted once the module has been executed (or has failed to load).
    Raises ImportError if no import spec can be built for the file. Network/HTTP errors from urlopen and
    exceptions raised while executing the module propagate; in the latter case the module is removed from sys.modules.
    """
    import importlib.util, sys, urllib.request, tempfile, os  # importlib.util is a submodule: `import importlib` alone does not guarantee it is loaded
    with urllib.request.urlopen(url) as f: source_code = f.read().decode("utf-8")
    # write UTF-8 explicitly: Python reads .py sources as UTF-8, whereas the default text encoding is locale-dependent
    with tempfile.NamedTemporaryFile(mode='w+', encoding='utf-8', delete=False, suffix=".py") as temp_file:
        temp_file.write(source_code)
        temp_path = temp_file.name
    module_name = import_path
    try:
        spec = importlib.util.spec_from_file_location(module_name, temp_path)
        if spec is None: raise ImportError(f'Failed to import module {module_name} from {url}')
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            sys.modules.pop(module_name, None)  # do not leave a half-initialised module behind
            raise
    finally:
        os.remove(temp_path)
    return module
pkg_req_dict = {}  # package name -> (requirement line, version string); filled in place by get_req_dict()
def get_req_dict(req_file):
    """
    Generates a dictionary of packages from a requirements file and stores it in the module-level pkg_req_dict.
    Only lines starting with "#$" are parsed; every other line of the file is ignored.
    The keys are package names, and the values are tuples of (line, version), where line is the text later handed to pip.
    req_file: path to the requirements file
    Returns pkg_req_dict. Once it is non-empty, later calls return it unchanged without reading req_file.
    """
    if pkg_req_dict: return pkg_req_dict
    package_name = r"([-A-Za-z0-9_.]+)"
    comp = r"(==|!=|<=|>=|<|>|~=|===)"
    ver_num = r"([0-9]+[0-9.*]*)"
    spec_pattern = re.compile(rf"{package_name}\s*{comp}\s*{ver_num}")
    def extract_package_info_from_line(line):
        """ line: a string from the requirements file that starts with "#$" """
        line = line[2:] # Remove the "#$"
        line = line.split("#")[0].strip() # Remove comments (note: this also cuts a URL at any '#' fragment)
        out = []
        for pkg in spec_pattern.findall(line):
            if '@' in line: line = line.replace('==' + pkg[2], '') # 'fairsearchcore==1.0.4@git...' >> fairsearchcore@git...
            out.append((pkg[0], (line, pkg[2])))
        return out # [(package_name, (line, ver_num)), ...]
    with open(req_file, 'r', encoding='utf-8') as f:
        # update in place so any module that already holds a reference to pkg_req_dict sees the entries
        pkg_req_dict.update(chain.from_iterable(extract_package_info_from_line(line) for line in f if line.startswith("#$")))
    # log.info(f'Required packages: {pkg_req_dict}')
    return pkg_req_dict
# Populate pkg_req_dict once, at import time. The relative path is resolved against the current working directory,
# not against this file's location.
get_req_dict(req_file='../requirements.txt')
def set_seed(seed, torch=None):
    """
    Seed the random number generators so that runs are reproducible.
    seed: the seed value; None leaves every generator untouched
    torch: optionally the imported torch module. If given, torch is seeded too. When CUDA is available, all GPUs
           are seeded and cuDNN is switched to deterministic, non-benchmarking mode.
    """
    if seed is None: return
    for seed_fn in (random.seed, numpy.random.seed): seed_fn(seed)
    if not torch: return
    torch.manual_seed(seed)
    # torch.use_deterministic_algorithms(True) stays off: with CUDA >= 10.2, ops such as leaky_relu then raise a RuntimeError
    # unless CUBLAS_WORKSPACE_CONFIG=:4096:8 (or :16:8) is set before start-up.
    # See https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
    if not torch.cuda.is_available(): return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def cfg2str(cfg):
    """Flatten a config into 'k1v1.k2v2...' (key immediately followed by its resolved value); '' for an empty/None config."""
    if not cfg: return ''
    container = OmegaConf.to_container(cfg, resolve=True)
    return '.'.join(f'{key}{val}' for key, val in container.items())
def str2cfg(s):
    """
    Parse a dot-separated key/value string into an OmegaConf config, e.g. 'x1.y2.z3' -> {x: 1, y: 2, z: 3}.
    In each '.'-separated token, the letters form the key and the digits form the value (as int). A token with
    no digits maps its key to ''. All other characters (such as '_' or '-') are dropped.
    """
    config = {}
    for token in s.split('.'):
        key = ''.join(ch for ch in token if ch.isalpha())
        digits = ''.join(ch for ch in token if ch.isdigit())
        config[key] = int(digits) if digits.isdigit() else digits
    return OmegaConf.create(config)
# ANSI escape sequences for colouring log/console text; append 'reset' to return to the default colour.
textcolor = dict(
    blue='\033[94m',
    green='\033[92m',
    yellow='\033[93m',
    red='\033[91m',
    magenta='\033[95m',
    cyan='\033[96m',
    reset='\033[0m',
)