diff --git a/app.py b/app.py index a072b6b..cb3d2f2 100644 --- a/app.py +++ b/app.py @@ -9,6 +9,8 @@ ) from submodules.model import session +import traceback + app = FastAPI() @@ -17,6 +19,9 @@ async def handle_db_session(request: Request, call_next): session_token = general.get_ctx_token() try: response = await call_next(request) + except Exception: + print(traceback.format_exc(), flush=True) + response = None finally: general.remove_and_refresh_session(session_token) @@ -66,6 +71,7 @@ class MostSimilarByEmbeddingRequest(BaseModel): att_filter: Optional[List[Dict[str, Any]]] = None threshold: Optional[Union[float, int]] = None question: Optional[str] = None + user_id: Optional[str] = None @app.post("/most_similar_by_embedding") @@ -99,6 +105,7 @@ def most_similar_by_embedding( request.att_filter, request.threshold, include_scores, + request.user_id, ) if request.question: diff --git a/neural_search/util.py b/neural_search/util.py index f58b3da..093a763 100644 --- a/neural_search/util.py +++ b/neural_search/util.py @@ -10,10 +10,18 @@ embedding, record_label_association, record, + project, + user, ) -from submodules.model.enums import EmbeddingPlatform, LabelSource +from submodules.model.cognition_objects import group_member +from submodules.model.integration_objects.helper import ( + REFINERY_ATTRIBUTE_ACCESS_GROUPS, + REFINERY_ATTRIBUTE_ACCESS_USERS, +) +from submodules.model.enums import EmbeddingPlatform, LabelSource, UserRoles from .similarity_threshold import SimilarityThreshold, NO_THRESHOLD_INDICATOR +import traceback port = int(os.environ["QDRANT_PORT"]) qdrant_client = QdrantClient(host="qdrant", port=port, timeout=60) @@ -48,9 +56,25 @@ def most_similar_by_embedding( att_filter: Optional[List[Dict[str, Any]]] = None, threshold: Optional[float] = None, include_scores: bool = False, + user_id: Optional[str] = None, ) -> List[str]: if not is_filter_valid_for_embedding(project_id, embedding_id, att_filter): return [] + if project.check_access_management_active(project_id): + if not user_id: + return [] + requesting_user = user.get(user_id) + if not requesting_user: + return [] + if requesting_user.role != UserRoles.ENGINEER.value: + check_access = True + group_members = group_member.get_by_user_id(user_id) + group_ids = [str(group_member.group_id) for group_member in group_members] + else: + check_access = False + else: + check_access = False + tmp_limit = limit has_sub_key = embedding.has_sub_key(project_id, embedding_id) if has_sub_key: @@ -66,14 +90,20 @@ def most_similar_by_embedding( elif similarity_threshold == NO_THRESHOLD_INDICATOR: similarity_threshold = None try: + _filter = __build_filter(att_filter) + if check_access: + _filter = __add_access_management_filter(_filter, group_ids, user_id) + search_result = qdrant_client.search( collection_name=embedding_id, query_vector=query_vector, - query_filter=__build_filter(att_filter), + query_filter=_filter, limit=tmp_limit, score_threshold=similarity_threshold, ) - except Exception: + except Exception as e: + print(f"Error during search in Qdrant: {e}", flush=True) + print(traceback.format_exc(), flush=True) return [] if include_scores: @@ -118,39 +148,61 @@ def __is_label_filter(key: str) -> bool: return parts[0] == LABELS_QDRANT -def __build_filter(att_filter: List[Dict[str, Any]]) -> models.Filter: - if att_filter is None or len(att_filter) == 0: +def __build_filter(att_filter: List[Dict[str, Any]]) -> Optional[models.Filter]: + if not att_filter: return None - must = [__build_filter_item(filter_item) for filter_item in att_filter] + must = [__build_filter_item(item) for item in att_filter] return models.Filter(must=must) +def __add_access_management_filter( + base_filter: Optional[models.Filter], group_ids: List[str], user_id: str +) -> models.Filter: + access_conditions = [ + models.FieldCondition( + key=REFINERY_ATTRIBUTE_ACCESS_GROUPS, + match=models.MatchAny(any=group_ids), + ), + models.FieldCondition( + key=REFINERY_ATTRIBUTE_ACCESS_USERS, + match=models.MatchValue(value=user_id), + ), + ] + + if base_filter is None: + return models.Filter(should=access_conditions) + + return models.Filter( + must=base_filter.must or [], + should=access_conditions, + ) + + def __build_filter_item(filter_item: Dict[str, Any]) -> models.FieldCondition: - if isinstance(filter_item["value"], list): - if filter_item.get("type") == "between": - return models.FieldCondition( - key=filter_item["key"], - range=models.Range( - gte=filter_item["value"][0], - lte=filter_item["value"][1], - ), - ) - else: - should = [ - models.FieldCondition( - key=filter_item["key"], match=models.MatchValue(value=value) - ) - for value in filter_item["value"] - ] - return models.Filter(should=should) - else: + key = filter_item["key"] + value = filter_item["value"] + typ = filter_item.get("type") + + # BETWEEN + if isinstance(value, list) and typ == "between": + return models.FieldCondition( + key=key, + range=models.Range(gte=value[0], lte=value[1]), + ) + + # IN (...) + if isinstance(value, list): return models.FieldCondition( - key=filter_item["key"], - match=models.MatchValue( - value=filter_item["value"], - ), + key=key, + match=models.MatchAny(any=value), ) + # = single value + return models.FieldCondition( + key=key, + match=models.MatchValue(value=value), + ) + def recreate_collection(project_id: str, embedding_id: str) -> int: embedding_item = embedding.get(project_id, embedding_id) diff --git a/submodules/model b/submodules/model index 4cdfbd2..19c0a4d 160000 --- a/submodules/model +++ b/submodules/model @@ -1 +1 @@ -Subproject commit 4cdfbd240114f22ba493d9a552b812499e0c5298 +Subproject commit 19c0a4d25233fa0a7d5c4ee5377954c0594d2750