|
1 | 1 | import base64 |
2 | 2 | import functools |
3 | 3 | import hashlib |
| 4 | +import importlib.metadata |
4 | 5 | import logging |
5 | 6 | import os |
| 7 | +import platform |
| 8 | +import socket |
| 9 | +import subprocess |
6 | 10 | from dataclasses import dataclass, field |
7 | 11 | from datetime import datetime |
8 | 12 | from getpass import getpass |
9 | 13 | from itertools import product |
| 14 | +from multiprocessing.shared_memory import SharedMemory |
10 | 15 | from typing import Any, Dict, List, Tuple |
11 | | -import subprocess |
12 | | -import importlib.metadata |
| 16 | + |
13 | 17 | import datajoint as dj |
14 | 18 | import numpy as np |
15 | | -from scipy import ndimage |
16 | | -import platform |
17 | | -import socket |
| 19 | + |
| 20 | +try: |
| 21 | + import yaml |
| 22 | + IMPORT_YALM = True |
| 23 | +except ImportError: |
| 24 | + IMPORT_YALM = False |
| 25 | + |
| 26 | +try: |
| 27 | + from scipy import ndimage |
| 28 | + IMPORT_SCIPY = True |
| 29 | +except ImportError: |
| 30 | + IMPORT_SCIPY = False |
18 | 31 |
|
19 | 32 | log = logging.getLogger(__name__) |
20 | 33 |
|
@@ -107,6 +120,11 @@ def pol2cart(phi, rho): |
107 | 120 | y = rho * np.sin(phi) |
108 | 121 | return (x, y) |
109 | 122 |
|
| 123 | + if not globals()["IMPORT_SCIPY"]: |
| 124 | + raise ImportError( |
| 125 | + "you need to install the scipy: sudo pip3 install scipy" |
| 126 | + ) |
| 127 | + |
110 | 128 | params = dict( |
111 | 129 | {"center_x": 0, "center_y": 0, "method": "index"}, **kwargs |
112 | 130 | ) # center_x, center_y points in normalized x coordinates from center |
@@ -458,3 +476,73 @@ def get_environment_info() -> Dict[str, Any]: |
458 | 476 | "hostname": socket.gethostname(), |
459 | 477 | "username": os.getlogin() if hasattr(os, "getlogin") else "unknown", |
460 | 478 | } |
| 479 | + |
| 480 | + |
| 481 | +def read_yalm(path: str, filename: str, variable: str) -> Any: |
| 482 | + """ |
| 483 | + Read a YAML file and return a specific variable. |
| 484 | +
|
| 485 | + Parameters: |
| 486 | + path (str): The path to the directory containing the file. |
| 487 | + filename (str): The name of the YAML file. |
| 488 | + variable (str): The name of the variable to retrieve from the YAML file. |
| 489 | +
|
| 490 | + Returns: |
| 491 | + Any: The value of the specified variable from the YAML file. |
| 492 | +
|
| 493 | + Raises: |
| 494 | + FileNotFoundError: If the specified file does not exist. |
| 495 | + KeyError: If the specified variable is not found in the YAML file. |
| 496 | + """ |
| 497 | + if not globals()["IMPORT_YALM"]: |
| 498 | + raise ImportError( |
| 499 | + "you need to install the skvideo: sudo pip3 install PyYAML" |
| 500 | + ) |
| 501 | + |
| 502 | + file_path = os.path.join(path, filename) |
| 503 | + if os.path.exists(file_path): |
| 504 | + with open(file_path, "r", encoding="UTF-8") as stream: |
| 505 | + file_yaml = yaml.safe_load(stream) |
| 506 | + try: |
| 507 | + return file_yaml[variable] |
| 508 | + except KeyError as exc: |
| 509 | + raise KeyError(f"The variable '{variable}' is not found in the YAML file.") from exc |
| 510 | + else: |
| 511 | + raise FileNotFoundError(f"There is no file '{filename}' in directory: '{path}'") |
| 512 | + |
| 513 | +def shared_memory_array(name: str, rows_len: int, columns_len: int, dtype: str = "float32") -> tuple: |
| 514 | + """ |
| 515 | + Creates or retrieves a shared memory array. |
| 516 | +
|
| 517 | + Parameters: |
| 518 | + name (str): Name of the shared memory. |
| 519 | + rows_len (int): Number of rows in the array. |
| 520 | + columns_len (int): Number of columns in the array. |
| 521 | + dtype (str, optional): Data type of the array. Defaults to "float32". |
| 522 | +
|
| 523 | + Returns: |
| 524 | + tuple(numpy.ndarray, multiprocessing.shared_memory.SharedMemory): |
| 525 | + Shared memory array and SharedMemory object. |
| 526 | + dict with all the informations about the shared memory |
| 527 | + """ |
| 528 | + try: |
| 529 | + dtype_obj = np.dtype(dtype) |
| 530 | + bytes_per_item = dtype_obj.itemsize |
| 531 | + n_bytes = rows_len * columns_len * bytes_per_item |
| 532 | + |
| 533 | + # Create or retrieve the shared memory |
| 534 | + sm = SharedMemory(name=name, create=True, size=n_bytes) |
| 535 | + except FileExistsError: |
| 536 | + # Shared memory already exists, retrieve it |
| 537 | + sm = SharedMemory(name=name, create=False, size=n_bytes) |
| 538 | + except Exception as e: |
| 539 | + raise RuntimeError('Error creating/retrieving shared memory: ' + str(e)) from e |
| 540 | + |
| 541 | + # Create a numpy array that uses the shared memory |
| 542 | + shared_array = np.ndarray((rows_len, columns_len), dtype=dtype_obj, buffer=sm.buf) |
| 543 | + shared_array.fill(0) |
| 544 | + conf: Dict = {"name": "pose", |
| 545 | + "shape": (rows_len, columns_len), |
| 546 | + "dtype": dtype_obj} |
| 547 | + |
| 548 | + return shared_array, sm, conf |
0 commit comments