api/transfer.py (4 changes: 2 additions & 2 deletions)

@@ -218,7 +218,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
         i += 1
         if i >= 60:
             i = 0
-            general.remove_and_refresh_session(request_new=True)
+            general.remove_and_refresh_session(None, True)
         if tokenization.is_doc_bin_creation_running_or_queued(project_id):
             time.sleep(2)
             continue
@@ -233,7 +233,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
             break
         if i >= 60:
             i = 0
-            general.remove_and_refresh_session(request_new=True)
+            general.remove_and_refresh_session(None, True)

     current_att_id = attribute_ids[0]
     current_att = attribute.get(project_id, current_att_id)
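
Taken together, the two new call forms above (a bare `general.get_ctx_token()` whose return value is discarded, and `remove_and_refresh_session(None, True)` in place of `remove_and_refresh_session(request_new=True)`) suggest that the session token is now tracked inside the `general` module itself, for example in a `contextvars.ContextVar`, rather than threaded through every caller. A minimal sketch of what such an API could look like; everything below except the two public names is a hypothetical stand-in, not code from `submodules.model`:

```python
# Hedged sketch only. The real implementation lives in submodules.model
# and is not part of this diff; behavior is inferred from the call sites.
from contextvars import ContextVar
from typing import Optional

_active_token: ContextVar[Optional[str]] = ContextVar("db_ctx_token", default=None)


def _open_session() -> str:
    # hypothetical stand-in for binding a new scoped DB session
    return "session-token"


def _close_session(token: str) -> None:
    # hypothetical stand-in for removing the scoped session
    pass


def get_ctx_token() -> str:
    # still returns the token, but callers may now ignore it
    token = _open_session()
    _active_token.set(token)
    return token


def remove_and_refresh_session(
    token: Optional[str] = None, request_new: bool = False
) -> Optional[str]:
    # falling back to the context-tracked token is what lets call sites
    # drop their local session_token / ctx_token variables in this PR
    token = token or _active_token.get()
    if token is not None:
        _close_session(token)
    _active_token.set(None)
    return get_ctx_token() if request_new else None
```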

conftest.py (4 changes: 2 additions & 2 deletions)

@@ -22,9 +22,9 @@

 @pytest.fixture(scope="session", autouse=True)
 def database_session() -> Iterator[None]:
-    session_token = general.get_ctx_token()
+    general.get_ctx_token()
     yield
-    general.remove_and_refresh_session(session_token)
+    general.remove_and_refresh_session()


 @pytest.fixture(scope="session")
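
The fixture is now exactly an acquire/yield/release pair, which is the shape of a context manager. A sketch of a hypothetical wrapper that would fold the two calls into one `with` block (not something this PR adds; the import path is assumed from the rest of the codebase):

```python
# Hypothetical convenience wrapper, not part of this PR.
from contextlib import contextmanager
from typing import Iterator

from submodules.model.business_objects import general  # assumed path


@contextmanager
def ctx_session() -> Iterator[None]:
    general.get_ctx_token()
    try:
        yield
    finally:
        general.remove_and_refresh_session()
```

The fixture body would then reduce to `with ctx_session(): yield`.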

controller/attribute/manager.py (16 changes: 8 additions & 8 deletions)

@@ -215,7 +215,7 @@ def __add_running_id(
     attribute_name: str,
     for_retokenization: bool = True,
 ):
-    session_token = general.get_ctx_token()
+    general.get_ctx_token()
     attribute.add_running_id(
         project_id, attribute_name, for_retokenization, with_commit=True
     )
@@ -231,7 +231,7 @@ def __add_running_id(
             "project_id": str(project_id),
         },
     )
-    general.remove_and_refresh_session(session_token)
+    general.remove_and_refresh_session()


 def calculate_user_attribute_missing_records(
@@ -301,7 +301,7 @@ def __calculate_user_attribute_missing_records(
     attribute_id: str,
     include_rats: bool,
 ) -> None:
-    session_token = general.get_ctx_token()
+    general.get_ctx_token()

     all_records_count = record.count(project_id)
     count_delta = record.count_missing_delta(project_id, attribute_id)
@@ -329,7 +329,7 @@ def __calculate_user_attribute_missing_records(
            attribute_id=attribute_id,
            log="Attribute calculation failed",
        )
-        general.remove_and_refresh_session(session_token)
+        general.remove_and_refresh_session()
        return

    util.add_log_to_attribute_logs(
@@ -354,7 +354,7 @@ def __calculate_user_attribute_missing_records(
            attribute_id=attribute_id,
            log="Writing to the database failed.",
        )
-        general.remove_and_refresh_session(session_token)
+        general.remove_and_refresh_session()
        return
    util.add_log_to_attribute_logs(project_id, attribute_id, "Finished writing.")

@@ -394,7 +394,7 @@ def __calculate_user_attribute_missing_records(
                attribute_id=attribute_id,
                log="Writing to the database failed.",
            )
-            general.remove_and_refresh_session(session_token)
+            general.remove_and_refresh_session()
            return

    else:
@@ -410,7 +410,7 @@ def __calculate_user_attribute_missing_records(
            attribute_id=attribute_id,
            log="Writing to the database failed.",
        )
-        general.remove_and_refresh_session(session_token)
+        general.remove_and_refresh_session()
        return
    util.set_progress(project_id, attribute_item, 1.0)
    attribute.update(
@@ -424,7 +424,7 @@ def __calculate_user_attribute_missing_records(
    notification.send_organization_update(
        project_id, f"calculate_attribute:finished:{attribute_id}"
    )
-    general.remove_and_refresh_session(session_token)
+    general.remove_and_refresh_session()


 def __notify_attribute_calculation_failed(

controller/attribute/util.py (6 changes: 3 additions & 3 deletions)

@@ -512,7 +512,7 @@ def read_container_logs_thread(
     attribute_id: str,
     docker_container: Any,
 ) -> None:
-    ctx_token = general.get_ctx_token()
+    general.get_ctx_token()
     # needs to be refetched since it is not thread safe
     attribute_item = attribute.get(project_id, attribute_id)
     previous_progress = -1
@@ -522,7 +522,7 @@ def read_container_logs_thread(
         time.sleep(1)
         c += 1
         if c > 100:
-            ctx_token = general.remove_and_refresh_session(ctx_token, True)
+            general.remove_and_refresh_session(None, True)
             attribute_item = attribute.get(project_id, attribute_id)
             if not attribute_item:
                 break
@@ -564,7 +564,7 @@ def read_container_logs_thread(
            continue
        previous_progress = last_entry
        set_progress(project_id, attribute_item, last_entry * 0.8 + 0.05)
-    general.remove_and_refresh_session(ctx_token)
+    general.remove_and_refresh_session()


 def set_progress(
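
`read_container_logs_thread` shows the loop-side convention this PR settles on: tail a container and, every 100 iterations, drop the scoped session and request a fresh one, so a long-running tail never reads stale rows through an old session; objects fetched earlier must be refetched afterwards, since they are not thread safe. A condensed sketch of that shape, with `is_running` and `handle_logs` as hypothetical stand-ins for the Docker calls and the `general` import path assumed:

```python
# Sketch of the periodic-refresh loop above; is_running/handle_logs are
# hypothetical stand-ins, not functions from this repository.
import time
from typing import Callable

from submodules.model.business_objects import general  # assumed path


def tail_with_fresh_session(
    is_running: Callable[[], bool],
    handle_logs: Callable[[], None],
    refresh_every: int = 100,
) -> None:
    general.get_ctx_token()
    c = 0
    try:
        while is_running():
            time.sleep(1)
            c += 1
            if c > refresh_every:
                c = 0
                # refresh so a long tail does not hold one DB session forever;
                # anything read through the old session must be refetched
                general.remove_and_refresh_session(None, True)
            handle_logs()
    finally:
        general.remove_and_refresh_session()
```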

controller/payload/payload_scheduler.py (10 changes: 5 additions & 5 deletions)

@@ -97,7 +97,7 @@ def prepare_and_run_execution_pipeline(
     in_thread: bool = False,
 ) -> None:
     if in_thread:
-        ctx_token = general.get_ctx_token()
+        general.get_ctx_token()
     try:
         add_file_name, input_data = prepare_input_data_for_payload(
             information_source_item
@@ -123,7 +123,7 @@ def prepare_and_run_execution_pipeline(
         )
     finally:
         if in_thread:
-            general.reset_ctx_token(ctx_token, True)
+            general.reset_ctx_token(None, True)

 def prepare_input_data_for_payload(
     information_source_item: InformationSource,
@@ -452,7 +452,7 @@ def read_container_logs_thread(
     payload_id: str,
     docker_container: Any,
 ):
-    ctx_token = general.get_ctx_token()
+    general.get_ctx_token()
     # needs to be refetched since it is not thread safe
     information_source_payload = information_source.get_payload(project_id, payload_id)
     previous_progress = -1
@@ -462,7 +462,7 @@ def read_container_logs_thread(
         time.sleep(1)
         c += 1
         if c > 100:
-            ctx_token = general.remove_and_refresh_session(ctx_token, True)
+            general.remove_and_refresh_session(None, True)
             information_source_payload = information_source.get_payload(
                 project_id, payload_id
             )
@@ -504,7 +504,7 @@ def read_container_logs_thread(
        set_payload_progress(
            project_id, information_source_payload, last_entry, factor=0.8
        )
-    general.remove_and_refresh_session(ctx_token)
+    general.remove_and_refresh_session()


 def get_inference_dir() -> str:

controller/project/manager.py (4 changes: 2 additions & 2 deletions)

@@ -354,7 +354,7 @@ def check_in_deletion_projects() -> None:
 def __check_in_deletion_projects() -> None:
     # wait for startup to finish
     time.sleep(2)
-    ctx_token = general.get_ctx_token()
+    general.get_ctx_token()
     to_be_deleted = []
     orgs = organization.get_all()
     for org_item in orgs:
@@ -364,4 +364,4 @@ def __check_in_deletion_projects() -> None:
                to_be_deleted.append(str(project_item.id))
     for project_id in to_be_deleted:
         delete_project(project_id)
-    general.remove_and_refresh_session(ctx_token)
+    general.remove_and_refresh_session()

controller/record/manager.py (4 changes: 2 additions & 2 deletions)

@@ -134,11 +134,11 @@ def delete_all_records(project_id: str) -> None:


 def __reupload_embeddings(project_id: str) -> None:
-    ctx_token = general.get_ctx_token()
+    general.get_ctx_token()
     embeddings = embedding.get_finished_embeddings(project_id)
     for e in embeddings:
         embedding_manager.request_tensor_upload(project_id, str(e.id))
-    general.remove_and_refresh_session(ctx_token)
+    general.remove_and_refresh_session()


 def get_unique_values_by_attributes(project_id: str) -> Dict[str, List[str]]:

controller/record_label_association/manager.py (4 changes: 2 additions & 2 deletions)

@@ -156,12 +156,12 @@ def create_manual_classification_label(
 def __check_label_duplication_classification_and_react(
     project_id: str, record_id: str, user_id: str, label_ids: List[str]
 ):
-    ctx_token = general.get_ctx_token()
+    general.get_ctx_token()
     if check_label_duplication_classification(
         project_id, record_id, user_id, label_ids
     ):
         notification.send_organization_update(project_id, f"rla_deleted:{record_id}")
-    general.remove_and_refresh_session(ctx_token)
+    general.remove_and_refresh_session()


 def __update_label_payloads_for_neural_search(project_id: str, record_id: str):

controller/transfer/project_transfer_manager.py (6 changes: 3 additions & 3 deletions)

@@ -921,12 +921,12 @@ def __post_processing_import_threaded(
     user_id: str,
 ) -> None:
     time.sleep(5)
-    ctx_token = general.get_ctx_token()
+    general.get_ctx_token()
     c = 1
     while True:
         c += 1
         if c > 12:
-            ctx_token = general.remove_and_refresh_session(ctx_token, True)
+            general.remove_and_refresh_session(None, True)
             c = 1
         if task_queue.get_by_tokenization(project_id):
             logger.info(f"Waiting for tokenization of project {project_id}")
@@ -945,7 +945,7 @@ def __post_processing_import_threaded(
            embedding_manager.request_tensor_upload(
                project_id, str(embedding_ids[old_id])
            )
-    general.remove_and_refresh_session(ctx_token)
+    general.remove_and_refresh_session()


 def get_project_export_dump(

controller/weak_supervision/manager.py (4 changes: 2 additions & 2 deletions)

@@ -58,7 +58,7 @@ def execution_pipeline(
         overwrite_default_precision: Optional[float] = None,
         overwrite_weak_supervision: Optional[Dict[str, float]] = None,
     ):
-        ctx_token = general.get_ctx_token()
+        general.get_ctx_token()
        try:
            labeling_tasks = labeling_task.get_labeling_tasks_by_selected_sources(
                project_id
@@ -106,7 +106,7 @@ def execution_pipeline(
            )
            raise e
        finally:
-            general.reset_ctx_token(ctx_token)
+            general.reset_ctx_token()

     daemon.run_without_db_token(
         execution_pipeline,
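
The unchanged `daemon.run_without_db_token(execution_pipeline, ...)` call hints at the threading side of the change: judging by the name alone, the helper starts the pipeline on a thread that inherits no DB token, leaving `execution_pipeline` to open and close its own context session as shown above. A speculative sketch; the real helper is not shown in this diff:

```python
# Speculative sketch, inferred from the helper's name only.
import threading
from typing import Any, Callable


def run_without_db_token(
    target: Callable[..., Any], *args: Any, **kwargs: Any
) -> threading.Thread:
    # deliberately creates no session token here: the target is expected
    # to call general.get_ctx_token() itself (see execution_pipeline)
    thread = threading.Thread(target=target, args=args, kwargs=kwargs, daemon=True)
    thread.start()
    return thread
```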

middleware/database_session.py (9 changes: 4 additions & 5 deletions)

@@ -7,21 +7,20 @@
 from middleware import log_storage
 from fast_api.routes.client_response import GENERIC_FAILURE_RESPONSE
 import traceback
-
+from submodules.model.session_wrapper import run_db_async_with_session
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)


 async def handle_db_session(request: Request, call_next):
     info = _prepare_info(request)
-    session_token = general.get_ctx_token()
+    general.get_ctx_token()
     try:
-        request.state.session_token = session_token
         info.context = {"request": request}
         request.state.info = info
         request.state.parsed = {}

-        log_request = auth_manager.extract_state_info(request, "log_request")
+        log_request = await run_db_async_with_session(auth_manager.extract_state_info, request, "log_request")
         length = request.headers.get("content-length")

         if length and int(length) > 0:
@@ -37,7 +36,7 @@ async def handle_db_session(request: Request, call_next):
         print(traceback.format_exc(), flush=True)
         return GENERIC_FAILURE_RESPONSE
     finally:
-        general.remove_and_refresh_session(session_token)
+        general.remove_and_refresh_session()


 def _prepare_info(request):
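
Both middleware files now push the synchronous `auth_manager.extract_state_info` through `run_db_async_with_session` instead of calling it inline. Judging only from the call sites, the wrapper appears to take a sync callable plus its positional arguments, run it with its own scoped session off the event loop, and return an awaitable. A minimal sketch under those assumptions; the real implementation lives in `submodules.model.session_wrapper` and is not shown in this PR:

```python
# Assumed shape of run_db_async_with_session, inferred from call sites
# such as: await run_db_async_with_session(func, request, "log_request").
import asyncio
from typing import Any, Callable

from submodules.model.business_objects import general  # assumed path


async def run_db_async_with_session(func: Callable[..., Any], *args: Any) -> Any:
    def _with_session() -> Any:
        general.get_ctx_token()  # scoped session for the worker thread
        try:
            return func(*args)
        finally:
            general.remove_and_refresh_session()

    # run the blocking DB work in a thread (Python 3.9+) so the event
    # loop stays responsive while the session is in use
    return await asyncio.to_thread(_with_session)
```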

middleware/log_storage.py (7 changes: 4 additions & 3 deletions)

@@ -10,6 +10,7 @@
 from datetime import datetime
 from submodules.model.enums import AdminLogLevel, try_parse_enum_value
 from controller.auth import manager as auth_manager
+from submodules.model.session_wrapper import run_db_async_with_session


 __not_yet_persisted = {}  # {log_path: List[Dict[str,Any]]}
@@ -101,7 +102,7 @@ async def set_request_data(request: Request) -> bytes:


 async def log_request(request):
-    log_request = auth_manager.extract_state_info(request, "log_request")
+    log_request = await run_db_async_with_session(auth_manager.extract_state_info, request, "log_request")
     log_lvl: AdminLogLevel = try_parse_enum_value(log_request, AdminLogLevel, False)
     # lazy boolean resolution to avoid unnecessary calls
     if (
@@ -116,11 +117,11 @@ async def log_request(request):
         data = request.state.data

         now = datetime.now()
-        org_id = auth_manager.extract_state_info(request, "organization_id")
+        org_id = await run_db_async_with_session(auth_manager.extract_state_info, request, "organization_id")
         log_path = f"/logs/admin/{org_id}/{now.strftime('%Y-%m-%d')}.csv"
         log_entry = {
             "timestamp": now.strftime("%Y-%m-%d %H:%M:%S.%f"),
-            "user_id": auth_manager.extract_state_info(request, "user_id"),
+            "user_id": await run_db_async_with_session(auth_manager.extract_state_info, request, "user_id"),
             "gateway": "REFINERY",
             "method": str(request.method),
             "path": str(request.url.path),

middleware/starlette_tmp_middleware.py (4 changes: 2 additions & 2 deletions)

@@ -17,7 +17,7 @@ async def dispatch(self, request, call_next):
             # fast api middleware handles these
             return await call_next(request)

-        session_token = general.get_ctx_token()
+        general.get_ctx_token()
         try:
             response = await call_next(request)
             # finally is still called even if returned response
@@ -26,4 +26,4 @@ async def dispatch(self, request, call_next):
             print(traceback.format_exc(), flush=True)
             return GENERIC_FAILURE_RESPONSE
         finally:
-            general.remove_and_refresh_session(session_token)
+            general.remove_and_refresh_session()