Skip to content

Commit e04ebc1

Browse files
committed
add helper function for the dlc interface
1 parent 435a189 commit e04ebc1

File tree

1 file changed

+93
-5
lines changed

1 file changed

+93
-5
lines changed

src/ethopy/utils/helper_functions.py

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,33 @@
11
import base64
22
import functools
33
import hashlib
4+
import importlib.metadata
45
import logging
56
import os
7+
import platform
8+
import socket
9+
import subprocess
610
from dataclasses import dataclass, field
711
from datetime import datetime
812
from getpass import getpass
913
from itertools import product
14+
from multiprocessing.shared_memory import SharedMemory
1015
from typing import Any, Dict, List, Tuple
11-
import subprocess
12-
import importlib.metadata
16+
1317
import datajoint as dj
1418
import numpy as np
15-
from scipy import ndimage
16-
import platform
17-
import socket
19+
20+
try:
21+
import yaml
22+
IMPORT_YALM = True
23+
except ImportError:
24+
IMPORT_YALM = False
25+
26+
try:
27+
from scipy import ndimage
28+
IMPORT_SCIPY = True
29+
except ImportError:
30+
IMPORT_SCIPY = False
1831

1932
log = logging.getLogger(__name__)
2033

@@ -107,6 +120,11 @@ def pol2cart(phi, rho):
107120
y = rho * np.sin(phi)
108121
return (x, y)
109122

123+
if not globals()["IMPORT_SCIPY"]:
124+
raise ImportError(
125+
"you need to install the scipy: sudo pip3 install scipy"
126+
)
127+
110128
params = dict(
111129
{"center_x": 0, "center_y": 0, "method": "index"}, **kwargs
112130
) # center_x, center_y points in normalized x coordinates from center
@@ -458,3 +476,73 @@ def get_environment_info() -> Dict[str, Any]:
458476
"hostname": socket.gethostname(),
459477
"username": os.getlogin() if hasattr(os, "getlogin") else "unknown",
460478
}
479+
480+
481+
def read_yalm(path: str, filename: str, variable: str) -> Any:
482+
"""
483+
Read a YAML file and return a specific variable.
484+
485+
Parameters:
486+
path (str): The path to the directory containing the file.
487+
filename (str): The name of the YAML file.
488+
variable (str): The name of the variable to retrieve from the YAML file.
489+
490+
Returns:
491+
Any: The value of the specified variable from the YAML file.
492+
493+
Raises:
494+
FileNotFoundError: If the specified file does not exist.
495+
KeyError: If the specified variable is not found in the YAML file.
496+
"""
497+
if not globals()["IMPORT_YALM"]:
498+
raise ImportError(
499+
"you need to install the skvideo: sudo pip3 install PyYAML"
500+
)
501+
502+
file_path = os.path.join(path, filename)
503+
if os.path.exists(file_path):
504+
with open(file_path, "r", encoding="UTF-8") as stream:
505+
file_yaml = yaml.safe_load(stream)
506+
try:
507+
return file_yaml[variable]
508+
except KeyError as exc:
509+
raise KeyError(f"The variable '{variable}' is not found in the YAML file.") from exc
510+
else:
511+
raise FileNotFoundError(f"There is no file '{filename}' in directory: '{path}'")
512+
513+
def shared_memory_array(name: str, rows_len: int, columns_len: int, dtype: str = "float32") -> tuple:
514+
"""
515+
Creates or retrieves a shared memory array.
516+
517+
Parameters:
518+
name (str): Name of the shared memory.
519+
rows_len (int): Number of rows in the array.
520+
columns_len (int): Number of columns in the array.
521+
dtype (str, optional): Data type of the array. Defaults to "float32".
522+
523+
Returns:
524+
tuple(numpy.ndarray, multiprocessing.shared_memory.SharedMemory):
525+
Shared memory array and SharedMemory object.
526+
dict with all the informations about the shared memory
527+
"""
528+
try:
529+
dtype_obj = np.dtype(dtype)
530+
bytes_per_item = dtype_obj.itemsize
531+
n_bytes = rows_len * columns_len * bytes_per_item
532+
533+
# Create or retrieve the shared memory
534+
sm = SharedMemory(name=name, create=True, size=n_bytes)
535+
except FileExistsError:
536+
# Shared memory already exists, retrieve it
537+
sm = SharedMemory(name=name, create=False, size=n_bytes)
538+
except Exception as e:
539+
raise RuntimeError('Error creating/retrieving shared memory: ' + str(e)) from e
540+
541+
# Create a numpy array that uses the shared memory
542+
shared_array = np.ndarray((rows_len, columns_len), dtype=dtype_obj, buffer=sm.buf)
543+
shared_array.fill(0)
544+
conf: Dict = {"name": "pose",
545+
"shape": (rows_len, columns_len),
546+
"dtype": dtype_obj}
547+
548+
return shared_array, sm, conf

0 commit comments

Comments
 (0)