4
4
5
5
from pydantic import StrictStr
6
6
from pymilvus import (
7
- Collection ,
8
7
CollectionSchema ,
9
8
DataType ,
10
9
FieldSchema ,
20
19
)
21
20
from feast .infra .online_stores .online_store import OnlineStore
22
21
from feast .infra .online_stores .vector_store import VectorStoreConfig
23
- from feast .protos .feast .core .InfraObject_pb2 import InfraObject as InfraObjectProto
24
22
from feast .protos .feast .core .Registry_pb2 import Registry as RegistryProto
25
23
from feast .protos .feast .types .EntityKey_pb2 import EntityKey as EntityKeyProto
26
24
from feast .protos .feast .types .Value_pb2 import Value as ValueProto
27
25
from feast .repo_config import FeastConfigBaseModel , RepoConfig
28
26
from feast .type_map import (
29
27
PROTO_VALUE_TO_VALUE_TYPE_MAP ,
28
+ VALUE_TYPE_TO_PROTO_VALUE_MAP ,
30
29
feast_value_type_to_python_type ,
31
30
)
32
31
from feast .types import (
35
34
ComplexFeastType ,
36
35
PrimitiveFeastType ,
37
36
ValueType ,
37
+ from_feast_type ,
38
38
)
39
39
from feast .utils import (
40
40
_serialize_vector_to_float_list ,
@@ -146,9 +146,7 @@ def _get_or_create_collection(
146
146
collection_name = _table_id (config .project , table )
147
147
if collection_name not in self ._collections :
148
148
# Create a composite key by combining entity fields
149
- composite_key_name = (
150
- "_" .join ([field .name for field in table .entity_columns ]) + "_pk"
151
- )
149
+ composite_key_name = _get_composite_key_name (table )
152
150
153
151
fields = [
154
152
FieldSchema (
@@ -251,9 +249,8 @@ def online_write_batch(
251
249
).hex ()
252
250
# to recover the entity key just run:
253
251
# deserialize_entity_key(bytes.fromhex(entity_key_str), entity_key_serialization_version=3)
254
- composite_key_name = (
255
- "_" .join ([str (value ) for value in entity_key .join_keys ]) + "_pk"
256
- )
252
+ composite_key_name = _get_composite_key_name (table )
253
+
257
254
timestamp_int = int (to_naive_utc (timestamp ).timestamp () * 1e6 )
258
255
created_ts_int = (
259
256
int (to_naive_utc (created_ts ).timestamp () * 1e6 ) if created_ts else 0
@@ -293,8 +290,133 @@ def online_read(
293
290
table : FeatureView ,
294
291
entity_keys : List [EntityKeyProto ],
295
292
requested_features : Optional [List [str ]] = None ,
293
+ full_feature_names : bool = False ,
296
294
) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
297
- raise NotImplementedError
295
+ self .client = self ._connect (config )
296
+ collection_name = _table_id (config .project , table )
297
+ collection = self ._get_or_create_collection (config , table )
298
+
299
+ composite_key_name = _get_composite_key_name (table )
300
+
301
+ output_fields = (
302
+ [composite_key_name ]
303
+ + (requested_features if requested_features else [])
304
+ + ["created_ts" , "event_ts" ]
305
+ )
306
+ assert all (
307
+ field in [f ["name" ] for f in collection ["fields" ]]
308
+ for field in output_fields
309
+ ), (
310
+ f"field(s) [{ [field for field in output_fields if field not in [f ['name' ] for f in collection ['fields' ]]]} ] not found in collection schema"
311
+ )
312
+ composite_entities = []
313
+ for entity_key in entity_keys :
314
+ entity_key_str = serialize_entity_key (
315
+ entity_key ,
316
+ entity_key_serialization_version = config .entity_key_serialization_version ,
317
+ ).hex ()
318
+ composite_entities .append (entity_key_str )
319
+
320
+ query_filter_for_entities = (
321
+ f"{ composite_key_name } in ["
322
+ + ", " .join ([f"'{ e } '" for e in composite_entities ])
323
+ + "]"
324
+ )
325
+ self .client .load_collection (collection_name )
326
+ results = self .client .query (
327
+ collection_name = collection_name ,
328
+ filter = query_filter_for_entities ,
329
+ output_fields = output_fields ,
330
+ )
331
+ # Group hits by composite key.
332
+ grouped_hits : Dict [str , Any ] = {}
333
+ for hit in results :
334
+ key = hit .get (composite_key_name )
335
+ grouped_hits .setdefault (key , []).append (hit )
336
+
337
+ # Map the features to their Feast types.
338
+ feature_name_feast_primitive_type_map = {
339
+ f .name : f .dtype for f in table .features
340
+ }
341
+ # Build a dictionary mapping composite key -> (res_ts, res)
342
+ results_dict : Dict [
343
+ str , Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]
344
+ ] = {}
345
+
346
+ # here we need to map the data stored as characters back into the protobuf value
347
+ for hit in results :
348
+ key = hit .get (composite_key_name )
349
+ # Only take one hit per composite key (adjust if you need aggregation)
350
+ if key not in results_dict :
351
+ res = {}
352
+ res_ts = None
353
+ for field in output_fields :
354
+ val = ValueProto ()
355
+ field_value = hit .get (field , None )
356
+ if field_value is None and ":" in field :
357
+ _ , field_short = field .split (":" , 1 )
358
+ field_value = hit .get (field_short )
359
+
360
+ if field in ["created_ts" , "event_ts" ]:
361
+ res_ts = datetime .fromtimestamp (field_value / 1e6 )
362
+ elif field == composite_key_name :
363
+ # We do not return the composite key value
364
+ pass
365
+ else :
366
+ feature_feast_primitive_type = (
367
+ feature_name_feast_primitive_type_map .get (
368
+ field , PrimitiveFeastType .INVALID
369
+ )
370
+ )
371
+ feature_fv_dtype = from_feast_type (feature_feast_primitive_type )
372
+ proto_attr = VALUE_TYPE_TO_PROTO_VALUE_MAP .get (feature_fv_dtype )
373
+ if proto_attr :
374
+ if proto_attr == "bytes_val" :
375
+ setattr (val , proto_attr , field_value .encode ())
376
+ elif proto_attr in [
377
+ "int32_val" ,
378
+ "int64_val" ,
379
+ "float_val" ,
380
+ "double_val" ,
381
+ ]:
382
+ setattr (
383
+ val ,
384
+ proto_attr ,
385
+ type (getattr (val , proto_attr ))(field_value ),
386
+ )
387
+ elif proto_attr in [
388
+ "int32_list_val" ,
389
+ "int64_list_val" ,
390
+ "float_list_val" ,
391
+ "double_list_val" ,
392
+ ]:
393
+ setattr (
394
+ val ,
395
+ proto_attr ,
396
+ list (
397
+ map (
398
+ type (getattr (val , proto_attr )).__args__ [0 ],
399
+ field_value ,
400
+ )
401
+ ),
402
+ )
403
+ else :
404
+ setattr (val , proto_attr , field_value )
405
+ else :
406
+ raise ValueError (
407
+ f"Unsupported ValueType: { feature_feast_primitive_type } with feature view value { field_value } for feature { field } with value { field_value } "
408
+ )
409
+ # res[field] = val
410
+ key_to_use = field .split (":" , 1 )[- 1 ] if ":" in field else field
411
+ res [key_to_use ] = val
412
+ results_dict [key ] = (res_ts , res if res else None )
413
+
414
+ # Map the results back into a list matching the original order of the requested entity keys (composite_entities).
415
+ result_list = [
416
+ results_dict .get (key , (None , None )) for key in composite_entities
417
+ ]
418
+
419
+ return result_list
298
420
299
421
def update (
300
422
self ,
@@ -362,11 +484,7 @@ def retrieve_online_documents_v2(
362
484
"params" : {"nprobe" : 10 },
363
485
}
364
486
365
- composite_key_name = (
366
- "_" .join ([str (field .name ) for field in table .entity_columns ]) + "_pk"
367
- )
368
- # features_str = ", ".join([f"'{f}'" for f in requested_features])
369
- # expr = f" && feature_name in [{features_str}]"
487
+ composite_key_name = _get_composite_key_name (table )
370
488
371
489
output_fields = (
372
490
[composite_key_name ]
@@ -452,6 +570,10 @@ def _table_id(project: str, table: FeatureView) -> str:
452
570
return f"{ project } _{ table .name } "
453
571
454
572
573
def _get_composite_key_name(table: FeatureView) -> str:
    """Build the name of the composite primary-key field for a feature view.

    The name is formed by joining the names of the feature view's entity
    columns with underscores and appending the ``_pk`` suffix.

    Args:
        table: Feature view whose ``entity_columns`` supply the names.

    Returns:
        The composite key field name, e.g. ``"driver_id_pk"`` for a single
        entity column named ``driver_id``.
    """
    entity_names = (column.name for column in table.entity_columns)
    return f"{'_'.join(entity_names)}_pk"
575
+
576
+
455
577
def _extract_proto_values_to_dict (
456
578
input_dict : Dict [str , Any ],
457
579
vector_cols : List [str ],
@@ -462,6 +584,13 @@ def _extract_proto_values_to_dict(
462
584
for k in PROTO_VALUE_TO_VALUE_TYPE_MAP .keys ()
463
585
if k is not None and "list" in k and "string" not in k
464
586
]
587
+ numeric_types = [
588
+ "double_val" ,
589
+ "float_val" ,
590
+ "int32_val" ,
591
+ "int64_val" ,
592
+ "bool_val" ,
593
+ ]
465
594
output_dict = {}
466
595
for feature_name , feature_values in input_dict .items ():
467
596
for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP :
@@ -475,10 +604,18 @@ def _extract_proto_values_to_dict(
475
604
else :
476
605
vector_values = getattr (feature_values , proto_val_type ).val
477
606
else :
478
- if serialize_to_string and proto_val_type != "string_val" :
607
+ if (
608
+ serialize_to_string
609
+ and proto_val_type not in ["string_val" ] + numeric_types
610
+ ):
479
611
vector_values = feature_values .SerializeToString ().decode ()
480
612
else :
481
- vector_values = getattr (feature_values , proto_val_type )
613
+ if not isinstance (feature_values , str ):
614
+ vector_values = str (
615
+ getattr (feature_values , proto_val_type )
616
+ )
617
+ else :
618
+ vector_values = getattr (feature_values , proto_val_type )
482
619
output_dict [feature_name ] = vector_values
483
620
else :
484
621
if serialize_to_string :
@@ -487,39 +624,3 @@ def _extract_proto_values_to_dict(
487
624
output_dict [feature_name ] = feature_values
488
625
489
626
return output_dict
490
-
491
-
492
class MilvusTable(InfraObject):
    """
    A Milvus collection managed by Feast.

    Attributes:
        host: The host of the Milvus server.
        port: The port of the Milvus server.
        name: The name of the collection.
    """

    host: str
    port: int

    def __init__(self, host: str, port: int, name: str):
        """Store the connection details and attempt to connect.

        Args:
            host: The host of the Milvus server.
            port: The port of the Milvus server.
            name: The name of the collection; forwarded to the base class.

        Raises:
            NotImplementedError: Always, for now, because ``_connect`` is
                still a stub and is invoked at the end of construction.
        """
        super().__init__(name)
        self.host = host
        self.port = port
        self._connect()

    def _connect(self):
        """Open a connection to the Milvus server (not implemented yet)."""
        raise NotImplementedError

    def to_infra_object_proto(self) -> InfraObjectProto:
        """Serialize this object to its protobuf form (not implemented yet)."""
        raise NotImplementedError

    def update(self):
        """Apply pending changes to the collection (not implemented yet)."""
        raise NotImplementedError

    def teardown(self):
        """Drop the backing Milvus collection if it exists.

        The existence check goes through ``pymilvus.utility`` because the ORM
        ``Collection`` class exposes no ``exists()`` method, and building a
        ``Collection`` for a missing collection without a schema raises
        instead of returning a handle.
        """
        # Imported locally so this block does not depend on the module's
        # top-level pymilvus import list.
        from pymilvus import utility

        # NOTE(review): assumes the base class stores ``name`` on the instance
        # and that a default pymilvus connection is already established.
        if utility.has_collection(self.name):
            utility.drop_collection(self.name)
0 commit comments