Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
i += 1
if i >= 60:
i = 0
daemon.reset_session_token_in_thread()
general.remove_and_refresh_session(None, True)
if tokenization.is_doc_bin_creation_running_or_queued(project_id):
time.sleep(2)
continue
Expand All @@ -211,7 +211,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
break
if i >= 60:
i = 0
daemon.reset_session_token_in_thread()
general.remove_and_refresh_session(None, True)

current_att_id = attribute_ids[0]
current_att = attribute.get(project_id, current_att_id)
Expand Down
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

@pytest.fixture(scope="session", autouse=True)
def database_session() -> Iterator[None]:
session_token = general.get_ctx_token()
general.get_ctx_token()
yield
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()


@pytest.fixture(scope="session")
Expand Down
17 changes: 8 additions & 9 deletions controller/attribute/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __add_running_id(
attribute_name: str,
for_retokenization: bool = True,
):
session_token = general.get_ctx_token()
general.get_ctx_token()
attribute.add_running_id(
project_id, attribute_name, for_retokenization, with_commit=True
)
Expand All @@ -231,7 +231,7 @@ def __add_running_id(
"project_id": str(project_id),
},
)
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()


def calculate_user_attribute_all_records(
Expand Down Expand Up @@ -301,8 +301,7 @@ def __calculate_user_attribute_all_records(
attribute_id: str,
include_rats: bool,
) -> None:
session_token = general.get_ctx_token()

general.get_ctx_token()
try:
calculated_attributes = util.run_attribute_calculation_exec_env(
attribute_id=attribute_id, project_id=project_id, doc_bin="docbin_full"
Expand All @@ -320,7 +319,7 @@ def __calculate_user_attribute_all_records(
attribute_id=attribute_id,
log="Attribute calculation failed",
)
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()
return

util.add_log_to_attribute_logs(
Expand All @@ -345,7 +344,7 @@ def __calculate_user_attribute_all_records(
attribute_id=attribute_id,
log="Writing to the database failed.",
)
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()
return
util.add_log_to_attribute_logs(project_id, attribute_id, "Finished writing.")

Expand Down Expand Up @@ -385,7 +384,7 @@ def __calculate_user_attribute_all_records(
attribute_id=attribute_id,
log="Writing to the database failed.",
)
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()
return

else:
Expand All @@ -401,7 +400,7 @@ def __calculate_user_attribute_all_records(
attribute_id=attribute_id,
log="Writing to the database failed.",
)
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()
return
util.set_progress(project_id, attribute_item, 1.0)
attribute.update(
Expand All @@ -415,7 +414,7 @@ def __calculate_user_attribute_all_records(
notification.send_organization_update(
project_id, f"calculate_attribute:finished:{attribute_id}"
)
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()


def __notify_attribute_calculation_failed(
Expand Down
6 changes: 3 additions & 3 deletions controller/attribute/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def read_container_logs_thread(
attribute_id: str,
docker_container: Any,
) -> None:
ctx_token = general.get_ctx_token()
general.get_ctx_token()
# needs to be refetched since it is not thread safe
attribute_item = attribute.get(project_id, attribute_id)
previous_progress = -1
Expand All @@ -502,7 +502,7 @@ def read_container_logs_thread(
time.sleep(1)
c += 1
if c > 100:
ctx_token = general.remove_and_refresh_session(ctx_token, True)
general.remove_and_refresh_session(None, True)
attribute_item = attribute.get(project_id, attribute_id)
if not attribute_item:
break
Expand Down Expand Up @@ -544,7 +544,7 @@ def read_container_logs_thread(
continue
previous_progress = last_entry
set_progress(project_id, attribute_item, last_entry * 0.8 + 0.05)
general.remove_and_refresh_session(ctx_token)
general.remove_and_refresh_session()


def set_progress(
Expand Down
10 changes: 5 additions & 5 deletions controller/payload/payload_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def prepare_and_run_execution_pipeline(
in_thread: bool = False,
) -> None:
if in_thread:
ctx_token = general.get_ctx_token()
general.get_ctx_token()
try:
add_file_name, input_data = prepare_input_data_for_payload(
information_source_item
Expand All @@ -123,7 +123,7 @@ def prepare_and_run_execution_pipeline(
)
finally:
if in_thread:
general.reset_ctx_token(ctx_token, True)
general.reset_ctx_token(None, True)

def prepare_input_data_for_payload(
information_source_item: InformationSource,
Expand Down Expand Up @@ -452,7 +452,7 @@ def read_container_logs_thread(
payload_id: str,
docker_container: Any,
):
ctx_token = general.get_ctx_token()
general.get_ctx_token()
# needs to be refetched since it is not thread safe
information_source_payload = information_source.get_payload(project_id, payload_id)
previous_progress = -1
Expand All @@ -462,7 +462,7 @@ def read_container_logs_thread(
time.sleep(1)
c += 1
if c > 100:
ctx_token = general.remove_and_refresh_session(ctx_token, True)
general.remove_and_refresh_session(None, True)
information_source_payload = information_source.get_payload(
project_id, payload_id
)
Expand Down Expand Up @@ -504,7 +504,7 @@ def read_container_logs_thread(
set_payload_progress(
project_id, information_source_payload, last_entry, factor=0.8
)
general.remove_and_refresh_session(ctx_token)
general.remove_and_refresh_session()


def get_inference_dir() -> str:
Expand Down
4 changes: 2 additions & 2 deletions controller/project/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def check_in_deletion_projects() -> None:
def __check_in_deletion_projects() -> None:
# wait for startup to finish
time.sleep(2)
ctx_token = general.get_ctx_token()
general.get_ctx_token()
to_be_deleted = []
orgs = organization.get_all()
for org_item in orgs:
Expand All @@ -277,4 +277,4 @@ def __check_in_deletion_projects() -> None:
to_be_deleted.append(str(project_item.id))
for project_id in to_be_deleted:
delete_project(project_id)
general.remove_and_refresh_session(ctx_token)
general.remove_and_refresh_session()
4 changes: 2 additions & 2 deletions controller/record/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ def delete_all_records(project_id: str) -> None:


def __reupload_embeddings(project_id: str) -> None:
ctx_token = general.get_ctx_token()
general.get_ctx_token()
embeddings = embedding.get_finished_embeddings(project_id)
for e in embeddings:
embedding_manager.request_tensor_upload(project_id, str(e.id))
general.remove_and_refresh_session(ctx_token)
general.remove_and_refresh_session()


def get_unique_values_by_attributes(project_id: str) -> Dict[str, List[str]]:
Expand Down
4 changes: 2 additions & 2 deletions controller/record_label_association/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ def create_manual_classification_label(
def __check_label_duplication_classification_and_react(
project_id: str, record_id: str, user_id: str, label_ids: List[str]
):
ctx_token = general.get_ctx_token()
general.get_ctx_token()
if check_label_duplication_classification(
project_id, record_id, user_id, label_ids
):
notification.send_organization_update(project_id, f"rla_deleted:{record_id}")
general.remove_and_refresh_session(ctx_token)
general.remove_and_refresh_session()


def __update_label_payloads_for_neural_search(project_id: str, record_id: str):
Expand Down
6 changes: 3 additions & 3 deletions controller/transfer/project_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,12 +921,12 @@ def __post_processing_import_threaded(
user_id: str,
) -> None:
time.sleep(5)
ctx_token = general.get_ctx_token()
general.get_ctx_token()
c = 1
while True:
c += 1
if c > 12:
ctx_token = general.remove_and_refresh_session(ctx_token, True)
general.remove_and_refresh_session(None, True)
c = 1
if task_queue.get_by_tokenization(project_id):
logger.info(f"Waiting for tokenization of project {project_id}")
Expand All @@ -945,7 +945,7 @@ def __post_processing_import_threaded(
embedding_manager.request_tensor_upload(
project_id, str(embedding_ids[old_id])
)
general.remove_and_refresh_session(ctx_token)
general.remove_and_refresh_session()


def get_project_export_dump(
Expand Down
4 changes: 2 additions & 2 deletions controller/weak_supervision/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def execution_pipeline(
overwrite_default_precision: Optional[float] = None,
overwrite_weak_supervision: Optional[Dict[str, float]] = None,
):
ctx_token = general.get_ctx_token()
general.get_ctx_token()
try:
labeling_tasks = labeling_task.get_labeling_tasks_by_selected_sources(
project_id
Expand Down Expand Up @@ -106,7 +106,7 @@ def execution_pipeline(
)
raise e
finally:
general.reset_ctx_token(ctx_token)
general.reset_ctx_token()

daemon.run_without_db_token(
execution_pipeline,
Expand Down
9 changes: 4 additions & 5 deletions middleware/database_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@
from middleware import log_storage
from fast_api.routes.client_response import GENERIC_FAILURE_RESPONSE
import traceback

from submodules.model.session_wrapper import run_db_async_with_session
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


async def handle_db_session(request: Request, call_next):
info = _prepare_info(request)
session_token = general.get_ctx_token()
general.get_ctx_token()
try:
request.state.session_token = session_token
info.context = {"request": request}
request.state.info = info
request.state.parsed = {}

log_request = auth_manager.extract_state_info(request, "log_request")
log_request = await run_db_async_with_session(auth_manager.extract_state_info, request, "log_request")
length = request.headers.get("content-length")

if length and int(length) > 0:
Expand All @@ -37,7 +36,7 @@ async def handle_db_session(request: Request, call_next):
print(traceback.format_exc(), flush=True)
return GENERIC_FAILURE_RESPONSE
finally:
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()


def _prepare_info(request):
Expand Down
7 changes: 4 additions & 3 deletions middleware/log_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from datetime import datetime
from submodules.model.enums import AdminLogLevel, try_parse_enum_value
from controller.auth import manager as auth_manager
from submodules.model.session_wrapper import run_db_async_with_session


__not_yet_persisted = {} # {log_path: List[Dict[str,Any]]}
Expand Down Expand Up @@ -101,7 +102,7 @@ async def set_request_data(request: Request) -> bytes:


async def log_request(request):
log_request = auth_manager.extract_state_info(request, "log_request")
log_request = await run_db_async_with_session(auth_manager.extract_state_info, request, "log_request")
log_lvl: AdminLogLevel = try_parse_enum_value(log_request, AdminLogLevel, False)
# lazy boolean resolution to avoid unnecessary calls
if (
Expand All @@ -116,11 +117,11 @@ async def log_request(request):
data = request.state.data

now = datetime.now()
org_id = auth_manager.extract_state_info(request, "organization_id")
org_id = await run_db_async_with_session(auth_manager.extract_state_info, request, "organization_id"))
log_path = f"/logs/admin/{org_id}/{now.strftime('%Y-%m-%d')}.csv"
log_entry = {
"timestamp": now.strftime("%Y-%m-%d %H:%M:%S.%f"),
"user_id": auth_manager.extract_state_info(request, "user_id"),
"user_id": await run_db_async_with_session(auth_manager.extract_state_info, request, "user_id"),
"gateway": "REFINERY",
"method": str(request.method),
"path": str(request.url.path),
Expand Down
4 changes: 2 additions & 2 deletions middleware/starlette_tmp_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def dispatch(self, request, call_next):
# fast api middleware handles these
return await call_next(request)

session_token = general.get_ctx_token()
general.get_ctx_token()
try:
response = await call_next(request)
# finally is still called even if returned response
Expand All @@ -26,4 +26,4 @@ async def dispatch(self, request, call_next):
print(traceback.format_exc(), flush=True)
return GENERIC_FAILURE_RESPONSE
finally:
general.remove_and_refresh_session(session_token)
general.remove_and_refresh_session()