28
28
import numpy as np
29
29
30
30
from .operands import Fetch , ShuffleProxy
31
- from .graph import DAG , DirectedGraph
31
+ from .graph import DAG
32
32
from .config import options
33
33
from .tiles import IterativeChunkGraphBuilder , ChunkGraphBuilder , get_tiled
34
34
from .optimizes .runtime .core import RuntimeOptimizer
@@ -712,7 +712,7 @@ def _on_tile_success(before_tile_data, after_tile_data):
712
712
chunk_result = self ._chunk_result .copy ()
713
713
tileable_graph_builder = TileableGraphBuilder ()
714
714
tileable_graph = tileable_graph_builder .build ([tileable ])
715
- chunk_graph_builder = ChunkGraphBuilder (graph_cls = DirectedGraph , compose = compose ,
715
+ chunk_graph_builder = ChunkGraphBuilder (compose = compose ,
716
716
on_tile_success = _on_tile_success )
717
717
chunk_graph = chunk_graph_builder .build ([tileable ], tileable_graph = tileable_graph )
718
718
ret = self .execute_graph (chunk_graph , result_keys , n_parallel = n_parallel or n_thread ,
@@ -774,6 +774,7 @@ def execute_tileables(self, tileables, fetch=True, n_parallel=None, n_thread=Non
774
774
tileable_data_to_chunks = weakref .WeakKeyDictionary ()
775
775
776
776
node_to_fetch = weakref .WeakKeyDictionary ()
777
+ skipped_tileables = set ()
777
778
778
779
def _generate_fetch_tileable (node ):
779
780
# Attach chunks to fetch tileables to skip tile.
@@ -786,6 +787,22 @@ def _generate_fetch_tileable(node):
786
787
787
788
return node
788
789
790
+ def _skip_executed_tileables (inps ):
791
+ # skip the input that executed, and not gc collected
792
+ new_inps = []
793
+ for inp in inps :
794
+ if inp .key in self .stored_tileables :
795
+ try :
796
+ get_tiled (inp )
797
+ except KeyError :
798
+ new_inps .append (inp )
799
+ else :
800
+ skipped_tileables .add (inp )
801
+ continue
802
+ else :
803
+ new_inps .append (inp )
804
+ return new_inps
805
+
789
806
def _generate_fetch_if_executed (nd ):
790
807
# node processor that if the node is executed
791
808
# replace it with a fetch node
@@ -830,7 +847,8 @@ def _get_tileable_graph_builder(**kwargs):
830
847
with self ._gen_local_context (chunk_result ):
831
848
# build tileable graph
832
849
tileable_graph_builder = _get_tileable_graph_builder (
833
- node_processor = _generate_fetch_tileable )
850
+ node_processor = _generate_fetch_tileable ,
851
+ inputs_selector = _skip_executed_tileables )
834
852
tileable_graph = tileable_graph_builder .build (tileables )
835
853
chunk_graph_builder = IterativeChunkGraphBuilder (
836
854
graph_cls = DAG , node_processor = _generate_fetch_if_executed ,
0 commit comments