@@ -314,12 +314,13 @@ def stage_prepared_descriptor_batch(
314314 _validate_descriptor_image_arena_base (image , arena_base , "prepared descriptor batch" )
315315 _validate_descriptor_image_batch_index (image , batch_index , "prepared descriptor batch" )
316316 _validate_descriptor_image_op_names (prepared , image )
317- _validate_descriptor_image_words (image , descriptor_base )
317+ descriptor_count = _validate_descriptor_image_words (image , descriptor_base )
318318 sequence = prepared .get ("host_runtime_sequence" )
319319 if not isinstance (sequence , Mapping ):
320320 raise ValueError ("prepared descriptor batch requires host_runtime_sequence" )
321321 _validate_sequence_mmio_preamble (prepared , sequence )
322- _validate_sequence_submission_base (sequence , descriptor_base )
322+ _validate_descriptor_writeback_preamble (prepared , image )
323+ _validate_descriptor_submission (image , sequence , descriptor_base , descriptor_count )
323324 _validate_sequence_descriptor_memory_writes (sequence , image )
324325 result = stage_host_runtime_sequence (
325326 sequence ,
@@ -418,12 +419,13 @@ def stage_prepared_descriptor_execution_batches(
418419 raise ValueError (
419420 "prepared execution batch descriptor_base does not match descriptor_stride_bytes"
420421 )
421- _validate_descriptor_image_words (image , expected_descriptor_base )
422+ descriptor_count = _validate_descriptor_image_words (image , expected_descriptor_base )
422423 sequence = batch .get ("host_runtime_sequence" )
423424 if not isinstance (sequence , Mapping ):
424425 raise ValueError ("prepared execution batch requires host_runtime_sequence" )
425426 _validate_sequence_mmio_preamble (batch , sequence )
426- _validate_sequence_submission_base (sequence , expected_descriptor_base )
427+ _validate_descriptor_writeback_preamble (batch , image )
428+ _validate_descriptor_submission (image , sequence , expected_descriptor_base , descriptor_count )
427429 _validate_sequence_descriptor_memory_writes (sequence , image )
428430 sequences .append (sequence )
429431
@@ -515,23 +517,35 @@ def _validate_descriptor_image_op_names(batch: Mapping[str, Any], image: Mapping
515517 raise ValueError ("prepared execution batch op_names do not match op_mmio_preamble" )
516518
517519
518- def _validate_descriptor_image_words (image : Mapping [str , Any ], descriptor_base : int ) -> None :
520+ def _validate_descriptor_image_words (image : Mapping [str , Any ], descriptor_base : int ) -> int :
519521 descriptor_words = image .get ("descriptor_words" )
520522 if not isinstance (descriptor_words , list ) or not descriptor_words :
521523 raise ValueError ("prepared execution batch requires descriptor_words metadata" )
524+ if len (descriptor_words ) > CommandBuffer .MAX_ENTRIES :
525+ raise ValueError ("prepared execution batch descriptor_words exceed RTL ring window" )
522526 expected : dict [int , int ] = {}
523527 for descriptor_index , words in enumerate (descriptor_words ):
524528 if not isinstance (words , list ) or len (words ) != 4 :
525529 raise ValueError ("prepared execution batch descriptor_words entry must have four words" )
526530 for word_index , word in enumerate (words ):
527531 if not isinstance (word , int ) or word < 0 or word > 0xFFFF_FFFF :
528532 raise ValueError ("prepared execution batch descriptor_words value must be a uint32" )
533+ if word_index == 0 and not word & E1NpuRuntime .DESC_FLAG_VALID_OWNER :
534+ raise ValueError (
535+ "prepared execution batch descriptor word0 missing valid_owner bit"
536+ )
537+ if word_index == 0 :
538+ _validate_descriptor_word0 (word )
529539 address = (
530540 descriptor_base + descriptor_index * CommandBuffer .DESCRIPTOR_BYTES + word_index * 4
531541 )
532542 if address > 0xFFFF_FFFF :
533543 raise ValueError ("prepared execution batch descriptor_image address exceeds uint32" )
534544 expected [address ] = word
545+ if words [0 ] & E1NpuRuntime .DESC_FLAG_WRITEBACK_REQUEST and words [2 ] & 0x3 :
546+ raise ValueError (
547+ "prepared execution batch descriptor writeback address must be aligned"
548+ )
535549 descriptor_image = image .get ("descriptor_image" )
536550 if not isinstance (descriptor_image , Mapping ) or not descriptor_image :
537551 raise ValueError ("prepared execution batch requires descriptor_image metadata" )
@@ -543,19 +557,107 @@ def _validate_descriptor_image_words(image: Mapping[str, Any], descriptor_base:
543557 materialized [parsed_address ] = value
544558 if materialized != expected :
545559 raise ValueError ("prepared execution batch descriptor_words do not match descriptor_image" )
560+ return len (descriptor_words )
561+
562+
563+ def _validate_descriptor_word0 (word0 : int ) -> None :
564+ opcode = word0 & 0xF
565+ stream_to_scratch = bool (word0 & E1NpuRuntime .DESC_FLAG_STREAM_TO_SCRATCH )
566+ writeback_request = bool (word0 & E1NpuRuntime .DESC_FLAG_WRITEBACK_REQUEST )
567+ scratch_offset = (word0 >> 16 ) & 0x3F
568+ byte_count = (word0 >> 24 ) & 0x3F
569+ if not stream_to_scratch :
570+ raise ValueError ("prepared execution batch descriptor word0 missing stream_to_scratch bit" )
571+ if byte_count == 0 or byte_count & 0x3 :
572+ raise ValueError (
573+ "prepared execution batch descriptor byte_count must be positive and aligned"
574+ )
575+ if scratch_offset & 0x3 or scratch_offset + byte_count > E1NpuRuntime .SCRATCH_BYTES :
576+ raise ValueError ("prepared execution batch descriptor stream exceeds scratchpad" )
577+ if writeback_request and opcode not in (E1NpuRuntime .OP_GEMM_S8 , E1NpuRuntime .OP_GEMM_S4 ):
578+ raise ValueError ("prepared execution batch writeback_request requires GEMM opcode" )
579+
580+
581+ def _validate_descriptor_writeback_preamble (
582+ batch : Mapping [str , Any ], image : Mapping [str , Any ]
583+ ) -> None :
584+ descriptor_words = image .get ("descriptor_words" )
585+ op_mmio_preamble = batch .get ("op_mmio_preamble" )
586+ if not isinstance (descriptor_words , list ) or not isinstance (op_mmio_preamble , list ):
587+ raise ValueError ("prepared execution batch requires descriptor_words and op_mmio_preamble" )
588+ if len (descriptor_words ) != len (op_mmio_preamble ):
589+ raise ValueError ("prepared execution batch descriptor_words count does not match op_mmio_preamble" )
590+ for words , entry in zip (descriptor_words , op_mmio_preamble , strict = True ):
591+ if not isinstance (words , list ) or not words :
592+ raise ValueError ("prepared execution batch descriptor_words entry must have four words" )
593+ word0 = words [0 ]
594+ if not isinstance (word0 , int ):
595+ raise ValueError ("prepared execution batch descriptor_words value must be a uint32" )
596+ if not word0 & E1NpuRuntime .DESC_FLAG_WRITEBACK_REQUEST :
597+ continue
598+ if not isinstance (entry , Mapping ):
599+ raise TypeError ("prepared execution batch preamble entry must be a mapping" )
600+ preamble = entry .get ("mmio_preamble" )
601+ if not isinstance (preamble , Mapping ):
602+ raise ValueError ("prepared execution batch requires mmio_preamble metadata" )
603+ cfg = preamble .get ("GEMM_CFG" )
604+ if not isinstance (cfg , int ) or cfg < 0 or cfg > 0xFFFF_FFFF :
605+ raise ValueError ("prepared execution batch GEMM_CFG must be a uint32" )
606+ writeback_bytes = (cfg & 0x3 ) * ((cfg >> 8 ) & 0x3 ) * 4
607+ if writeback_bytes == 0 :
608+ raise ValueError (
609+ "prepared execution batch writeback_request requires nonzero GEMM output"
610+ )
546611
547612
548- def _validate_sequence_submission_base (sequence : Mapping [str , Any ], expected_base : int ) -> None :
613+ def _validate_descriptor_submission (
614+ image : Mapping [str , Any ],
615+ sequence : Mapping [str , Any ],
616+ expected_base : int ,
617+ descriptor_count : int ,
618+ ) -> None :
619+ submission = image .get ("submission" )
620+ if not isinstance (submission , Mapping ):
621+ raise ValueError ("prepared execution batch requires descriptor submission metadata" )
622+ base = _aligned_uint32 (
623+ submission .get ("base" ), "prepared execution batch descriptor submission base"
624+ )
625+ head = _nonnegative_int (
626+ submission .get ("head" ), "prepared execution batch descriptor submission head"
627+ )
628+ tail = _nonnegative_int (
629+ submission .get ("tail" ), "prepared execution batch descriptor submission tail"
630+ )
631+ if base != expected_base :
632+ raise ValueError ("prepared execution batch submission base does not match descriptor_base" )
633+ if head != 0 :
634+ raise ValueError ("prepared execution batch submission head must be zero" )
635+ if tail != descriptor_count :
636+ raise ValueError ("prepared execution batch submission tail does not match descriptor_words" )
637+
638+ sequence_submission : dict [str , int ] = {}
549639 for write in _required_sequence_list (sequence , "submission_mmio_writes" ):
550640 if not isinstance (write , Mapping ):
551641 raise TypeError ("host runtime sequence write entry must be a mapping" )
552- if write .get ("register" ) != "DESC_BASE" :
642+ register = write .get ("register" )
643+ if register not in {"DESC_BASE" , "DESC_HEAD" , "DESC_TAIL" }:
553644 continue
554645 _address , value = _sequence_write_address_value (write )
555- if value != expected_base :
556- raise ValueError ("prepared execution batch DESC_BASE does not match descriptor_base" )
557- return
558- raise ValueError ("prepared execution batch requires DESC_BASE submission write" )
646+ sequence_submission [register ] = value
647+ expected_submission = {
648+ "DESC_BASE" : base ,
649+ "DESC_HEAD" : head ,
650+ "DESC_TAIL" : tail ,
651+ }
652+ missing = set (expected_submission ) - set (sequence_submission )
653+ if missing :
654+ raise ValueError ("prepared execution batch submission_mmio_writes missing register" )
655+ if sequence_submission ["DESC_BASE" ] != expected_submission ["DESC_BASE" ]:
656+ raise ValueError ("prepared execution batch DESC_BASE does not match descriptor_base" )
657+ if sequence_submission ["DESC_HEAD" ] != expected_submission ["DESC_HEAD" ]:
658+ raise ValueError ("prepared execution batch DESC_HEAD does not match submission" )
659+ if sequence_submission ["DESC_TAIL" ] != expected_submission ["DESC_TAIL" ]:
660+ raise ValueError ("prepared execution batch DESC_TAIL does not match submission" )
559661
560662
561663def _validate_sequence_descriptor_memory_writes (
0 commit comments