diff --git a/.gitignore b/.gitignore index 568c0e88c2..54eb8ef171 100644 --- a/.gitignore +++ b/.gitignore @@ -54,6 +54,11 @@ tags poetry.lock *.code-workspace .env +settings.yaml +settings.yml +deps/ +FINN_TMP +FINN_IP_CACHE # Package files *.egg diff --git a/README.md b/README.md index ed96e42ae1..013aea69f4 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ FINN+ incorporates all upstream FINN development while adding significant enhanc ### Developer Experience - **Better Diagnostics** - Improved logging and error handling throughout the framework -- **Type Safety** - Comprehensive type hinting and checking for better code quality +- **IP Caching** - IP Caching between builds for faster design iteration - **YAML Configuration** - Alternative YAML-based build configuration system - **Simplified Setup** - Containerless installation and setup process diff --git a/src/finn/builder/build_dataflow.py b/src/finn/builder/build_dataflow.py index 2977539e79..a96467d8c3 100644 --- a/src/finn/builder/build_dataflow.py +++ b/src/finn/builder/build_dataflow.py @@ -54,6 +54,7 @@ from finn.builder.build_dataflow_config import DataflowBuildConfig, default_build_dataflow_steps from finn.builder.build_dataflow_steps import build_dataflow_step_lookup +from finn.transformation.fpgadataflow.ip_cache import CACHE_IP_DEFINITIONS from finn.util.exception import ( FINNConfigurationError, FINNDataflowError, @@ -350,6 +351,18 @@ def build_dataflow_cfg(model_filename, cfg: DataflowBuildConfig): print(f"Final outputs will be generated in {cfg.output_dir}") print(f"Build log is at {logfile}") + # Printing all cached IPs + if cfg.use_ip_caching: + log.info("IP Caching enabled.") + if cfg.verbose: + log.info("Caching enabled for operators: ") + for k, v in CACHE_IP_DEFINITIONS.items(): + log.info(f"Operator: {k}:") + if "use" in v.keys(): + log.info("\tuse: " + ", ".join(v["use"])) + if "ignore" in v.keys(): + log.info("\nignore: " + ", ".join(v["ignore"])) + # Setup done, start build flow try: # If start_step is specified, override the input model diff --git a/src/finn/builder/build_dataflow_config.py b/src/finn/builder/build_dataflow_config.py index 2932d52881..75f6199320 100644 --- a/src/finn/builder/build_dataflow_config.py +++ b/src/finn/builder/build_dataflow_config.py @@ -156,8 +156,7 @@ class VerificationStepType(str, Enum): "step_minimize_bit_width", "step_generate_estimate_reports", "step_set_fifo_depths", - "step_hw_codegen", - "step_hw_ipgen", + "step_ip_generation", "step_create_stitched_ip", "step_measure_rtlsim_performance", "step_out_of_context_synthesis", @@ -334,6 +333,25 @@ class DataflowBuildConfig(DataClassJSONMixin, DataClassYAMLMixin): #: If not specified it will default to synth_clk_period_ns hls_clk_period_ns: Optional[float] = None + #: Use an IP Cache to re-use code-gen (PrepareIP) and HLS (HLSSynthIP) + #: artifacts from previous runs to speed up the build process. + use_ip_caching: bool = True + + #: (Only relevant if use_ip_caching is enabled) + #: Hash function to be used when caching the IP cores. + ip_cache_hashfunction: str = "sha256" + + #: (Only relevant if use_ip_caching is enabled) + #: Whether the value of _resolve_hls_clk_period() is used as part of + #: the cached key. Can be turned off for more cache hits, but + #: then delivers an IP with an outdated constraints file. This + #: might affect OOC Synthesis and other parts of the design, use + #: at your own risk. + cache_hls_clk_period: bool = True + + #: The same as `cache_hls_clk_period`, but for the passed FPGA part. + cache_fpgapart: bool = True + #: (Optional, only relevant when shell_flow_type = VITIS_ALVEO) #: Which Vitis platform will be used, e.g. "xilinx_u250_xdma_201830_2". #: If not specified but "board" is specified, will use the FINN diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index 5838d6a095..2318dc54cd 100644 --- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -85,6 +85,7 @@ from finn.transformation.fpgadataflow.insert_dwc import InsertDWC from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO from finn.transformation.fpgadataflow.insert_tlastmarker import InsertTLastMarker +from finn.transformation.fpgadataflow.ip_cache import CachedIPGen from finn.transformation.fpgadataflow.make_driver import ( MakeCPPDriver, MakePYNQDriverInstrumentation, @@ -117,7 +118,7 @@ from finn.transformation.streamline.reorder import MakeMaxPoolNHWC from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds from finn.util.basic import get_liveness_threshold_cycles, get_rtlsim_trace_depth -from finn.util.exception import FINNUserError +from finn.util.exception import FINNConfigurationError, FINNUserError from finn.util.logging import log from finn.util.test import execute_parent @@ -521,6 +522,49 @@ def step_minimize_bit_width(model: ModelWrapper, cfg: DataflowBuildConfig): return model +def _make_hls_estimate_report(model: ModelWrapper, cfg: DataflowBuildConfig) -> None: + report_dir = cfg.output_dir + "/report" + os.makedirs(report_dir, exist_ok=True) + estimate_layer_resources_hls = model.analysis(hls_synth_res_estimation) + estimate_layer_resources_hls["total"] = aggregate_dict_keys(estimate_layer_resources_hls) + with open(report_dir + "/estimate_layer_resources_hls.json", "w") as f: + json.dump(estimate_layer_resources_hls, f, indent=2) + + +def step_ip_generation(model: ModelWrapper, cfg: DataflowBuildConfig) -> ModelWrapper: + """Unified step, that does what step_hw_codegen and step_hw_ipgen did before. (With cache!).""" + if cfg.use_ip_caching: + clk = cfg._resolve_hls_clk_period() + if clk is None: + # TODO: Change into a logging error instead of an exception? + raise FINNConfigurationError( + "Please specify synth_clk_period_ns in your build " + "config (and optionally hls_clk_period_ns) before " + "generating IPs!" + ) + model = model.transform( + CachedIPGen( + cfg.ip_cache_hashfunction, + include_prepare_ip=True, + cache_clock=cfg.cache_hls_clk_period, + fpgapart=cfg._resolve_fpga_part(), + clk=clk, + cache_fpgapart=cfg.cache_fpgapart, + ) + ) + else: + model = model.transform(PrepareIP(cfg._resolve_fpga_part(), cfg._resolve_hls_clk_period())) + model = model.transform(HLSSynthIP()) + model = model.transform(ReplaceVerilogRelPaths()) + _make_hls_estimate_report(model, cfg) + + if VerificationStepType.NODE_BY_NODE_RTLSIM in cfg._resolve_verification_steps(): + model = model.transform(PrepareRTLSim()) + model = model.transform(SetExecMode("rtlsim")) + verify_step(model, cfg, "node_by_node_rtlsim", need_parent=True) + return model + + def step_hw_codegen(model: ModelWrapper, cfg: DataflowBuildConfig): """Generate Vitis HLS code to prepare HLSBackend nodes for IP generation. And fills RTL templates for RTLBackend nodes.""" @@ -533,15 +577,36 @@ def step_hw_ipgen(model: ModelWrapper, cfg: DataflowBuildConfig): """Run Vitis HLS synthesis on generated code for HLSBackend nodes, in order to generate IP blocks. For RTL nodes this step does not do anything.""" - model = model.transform(HLSSynthIP()) - model = model.transform(ReplaceVerilogRelPaths()) - report_dir = cfg.output_dir + "/report" - os.makedirs(report_dir, exist_ok=True) - estimate_layer_resources_hls = model.analysis(hls_synth_res_estimation) - estimate_layer_resources_hls["total"] = aggregate_dict_keys(estimate_layer_resources_hls) - with open(report_dir + "/estimate_layer_resources_hls.json", "w") as f: - json.dump(estimate_layer_resources_hls, f, indent=2) + if cfg.use_ip_caching: + log.info("Using IP cache to fetch generated IPs...") + clk = cfg._resolve_hls_clk_period() + if clk is None and cfg.cache_hls_clk_period: + log.critical( + "No HLS/general synthesis clock period was specified, but required for " + "caching (cfg.cache_hls_clk_period). Skipping caching for safety. " + "Executing just HLSSynthIP()..." + ) + model = model.transform(HLSSynthIP()) + else: + # If clk is None but we don't use it anways, give it some placeholder value + if clk is None: + clk = 0 + model = model.transform( + CachedIPGen( + cfg.ip_cache_hashfunction, + cache_clock=cfg.cache_hls_clk_period, + include_prepare_ip=False, + fpgapart=cfg._resolve_fpga_part(), + clk=clk, + cache_fpgapart=cfg.cache_fpgapart, + ) + ) + else: + log.info("Generating all IPs from scratch...") + model = model.transform(HLSSynthIP()) + model = model.transform(ReplaceVerilogRelPaths()) + _make_hls_estimate_report(model, cfg) if VerificationStepType.NODE_BY_NODE_RTLSIM in cfg._resolve_verification_steps(): model = model.transform(PrepareRTLSim()) model = model.transform(SetExecMode("rtlsim")) @@ -1059,6 +1124,7 @@ def step_deployment_package(model: ModelWrapper, cfg: DataflowBuildConfig): "step_apply_folding_config": step_apply_folding_config, "step_minimize_bit_width": step_minimize_bit_width, "step_generate_estimate_reports": step_generate_estimate_reports, + "step_ip_generation": step_ip_generation, "step_hw_codegen": step_hw_codegen, "step_hw_ipgen": step_hw_ipgen, "step_set_fifo_depths": step_set_fifo_depths, diff --git a/src/finn/custom_op/fpgadataflow/hls/__init__.py b/src/finn/custom_op/fpgadataflow/hls/__init__.py index ecd53231fe..8c07888627 100644 --- a/src/finn/custom_op/fpgadataflow/hls/__init__.py +++ b/src/finn/custom_op/fpgadataflow/hls/__init__.py @@ -28,6 +28,7 @@ from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp +from finn.transformation.fpgadataflow.ip_cache import cache_ip # Dictionary of HLSBackend implementations custom_op = dict() @@ -117,3 +118,8 @@ def register_custom_op(cls): custom_op["SplitMultiHeads_hls"] = SplitMultiHeads_hls custom_op["MergeMultiHeads_hls"] = MergeMultiHeads_hls custom_op["ReplicateStream_hls"] = ReplicateStream_hls + +# Apply cache to all ops +for key in custom_op.keys(): + if issubclass(custom_op[key], HWCustomOp): + custom_op[key] = cache_ip(attributes=None)(custom_op[key]) diff --git a/src/finn/custom_op/fpgadataflow/rtl/__init__.py b/src/finn/custom_op/fpgadataflow/rtl/__init__.py index 06067a4fca..1f5e54e99a 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/__init__.py +++ b/src/finn/custom_op/fpgadataflow/rtl/__init__.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp from finn.custom_op.fpgadataflow.rtl.convolutioninputgenerator_rtl import ( ConvolutionInputGenerator_rtl, ) @@ -37,6 +38,7 @@ from finn.custom_op.fpgadataflow.rtl.streamingfifo_rtl import StreamingFIFO_rtl from finn.custom_op.fpgadataflow.rtl.thresholding_rtl import Thresholding_rtl from finn.custom_op.fpgadataflow.rtl.vectorvectoractivation_rtl import VVAU_rtl +from finn.transformation.fpgadataflow.ip_cache import cache_ip custom_op = dict() @@ -49,3 +51,8 @@ custom_op["MVAU_rtl"] = MVAU_rtl custom_op["VVAU_rtl"] = VVAU_rtl custom_op["Thresholding_rtl"] = Thresholding_rtl + +# Apply cache to all ops +for key in custom_op.keys(): + if issubclass(custom_op[key], HWCustomOp): + custom_op[key] = cache_ip(attributes=None)(custom_op[key]) diff --git a/src/finn/interface/interface_utils.py b/src/finn/interface/interface_utils.py index 8d7a40b042..7f6a4a20b9 100644 --- a/src/finn/interface/interface_utils.py +++ b/src/finn/interface/interface_utils.py @@ -133,6 +133,27 @@ def resolve_deps_path(deps: Path | None, settings: dict) -> Path | None: return None +def resolve_cache_path(cache: Path | None, settings: dict) -> Path: + """Resolve the path to the IP cache. Always returns a valid Path. + + Resolution order is: + Command Line Argument -> Environment var -> Settings -> Default (finn-plus/FINN_IP_CACHE) + """ + if cache is not None: + return cache + if "FINN_IP_CACHE" in os.environ.keys(): + p = Path(os.environ["FINN_IP_CACHE"]) + if p.is_absolute(): + return p + return Path(__file__).parent.parent.parent.parent / p + if "FINN_IP_CACHE" in settings.keys(): + p = Path(settings["FINN_IP_CACHE"]) + if p.is_absolute(): + return p + return Path(__file__).parent.parent.parent.parent / p + return Path(__file__).parent.parent.parent.parent / "FINN_IP_CACHE" + + def resolve_num_workers(num: int, settings: dict) -> int: """Resolve the number of workers to use. Uses 75% of cores available as default fallback""" if num > -1: diff --git a/src/finn/interface/run_finn.py b/src/finn/interface/run_finn.py index d2997ecdac..68395476f8 100644 --- a/src/finn/interface/run_finn.py +++ b/src/finn/interface/run_finn.py @@ -22,6 +22,7 @@ assert_path_valid, error, resolve_build_dir, + resolve_cache_path, resolve_deps_path, resolve_num_workers, set_synthesis_tools_paths, @@ -52,16 +53,20 @@ def _resolve_module_path(name: str) -> str: def prepare_finn( deps: Path | None, + cache_path: Path | None, flow_config: Path, build_dir: Path | None, num_workers: int, is_test_run: bool = False, skip_dep_update: bool = False, ) -> None: - """Prepare a FINN environment by: + """Prepare a FINN environment. Leaves this process ready to run any FINN related script. + + This is done by: 0. Reading all settings and environment vars 1. Updating all dependencies 2. Setting all environment vars + 3. Installing depdendencies """ # Resolve settings and dependencies, error if this doesnt work if not settings_found(): @@ -70,6 +75,8 @@ def prepare_finn( sp = _resolve_settings_path() status(f"Using settings file at {sp}") settings = get_settings(force_update=True) + + # Set deps envvar deps_path = resolve_deps_path(deps, settings) if deps_path is None: error("Dependency location could not be resolved!") @@ -78,6 +85,12 @@ def prepare_finn( status(f"Using dependency path: {deps_path}") os.environ["FINN_DEPS"] = str(deps_path.absolute()) + # Set cache envvar + resolved_cache_path = str(resolve_cache_path(cache_path, settings).absolute()) + os.environ["FINN_IP_CACHE"] = resolved_cache_path + status(f"IP Cache set to: {resolved_cache_path}") + + # Clear PYTHONPATH if "PYTHONPATH" not in os.environ.keys(): os.environ["PYTHONPATH"] = "" @@ -132,6 +145,7 @@ def main_group() -> None: @click.command(help="Build a hardware design") @click.option("--dependency-path", "-d", default="") @click.option("--build-path", "-b", help="Specify a build temp path of your choice", default="") +@click.option("--ip-cache-path", "-c", help="Path to the FINN IP Cache directory", default="") @click.option( "--num-workers", "-n", @@ -163,6 +177,7 @@ def main_group() -> None: def build( dependency_path: str, build_path: str, + ip_cache_path: str, num_workers: int, skip_dep_update: bool, start: str, @@ -176,9 +191,11 @@ def build( assert_path_valid(config_path) assert_path_valid(model_path) dep_path = Path(dependency_path).expanduser() if dependency_path != "" else None + cache_path = Path(ip_cache_path).expanduser() if ip_cache_path != "" else None status(f"Starting FINN build with config {config_path.name} and model {model_path.name}!") prepare_finn( dep_path, + cache_path, config_path, build_dir, num_workers, @@ -234,6 +251,7 @@ def build( @click.command(help="Run a script in a FINN environment") @click.option("--dependency-path", "-d", default="") @click.option("--build-path", "-b", help="Specify a build temp path of your choice", default="") +@click.option("--ip-cache-path", "-c", help="Path to the FINN IP Cache directory", default="") @click.option( "--skip-dep-update", "-s", @@ -250,14 +268,21 @@ def build( ) @click.argument("script") def run( - dependency_path: str, build_path: str, skip_dep_update: bool, num_workers: int, script: str + dependency_path: str, + build_path: str, + ip_cache_path: str, + skip_dep_update: bool, + num_workers: int, + script: str, ) -> None: script_path = Path(script).expanduser() build_dir = Path(build_path).expanduser() if build_path != "" else None assert_path_valid(script_path) dep_path = Path(dependency_path).expanduser() if dependency_path != "" else None + cache_path = Path(ip_cache_path).expanduser() if ip_cache_path != "" else None prepare_finn( dep_path, + cache_path, script_path, build_dir, num_workers, @@ -286,7 +311,7 @@ def bench(bench_config: str, dependency_path: str, num_workers: int, build_path: console = Console() build_dir = Path(build_path).expanduser() if build_path != "" else None dep_path = Path(dependency_path).expanduser() if dependency_path != "" else None - prepare_finn(dep_path, Path(), build_dir, num_workers) + prepare_finn(dep_path, None, Path(), build_dir, num_workers) console.rule("RUNNING BENCHMARK") # Late import because we need prepare_finn to setup remaining dependencies first @@ -319,7 +344,7 @@ def test( console = Console() build_dir = Path(build_path).expanduser() if build_path != "" else None dep_path = Path(dependency_path).expanduser() if dependency_path != "" else None - prepare_finn(dep_path, Path(), build_dir, num_workers, is_test_run=True) + prepare_finn(dep_path, None, Path(), build_dir, num_workers, is_test_run=True) status(f"Using {num_test_workers} test workers") console.rule("RUNNING TESTS") run_test(variant, num_test_workers) @@ -340,7 +365,7 @@ def deps() -> None: ) def update(path: str) -> None: dep_path = Path(path).expanduser() if path != "" else None - prepare_finn(dep_path, Path(), None, 1) + prepare_finn(deps=dep_path, cache_path=None, flow_config=Path(), build_dir=None, num_workers=1) @click.group(help="Manage FINN settings") @@ -395,7 +420,7 @@ def config_set(key: str, value: str) -> None: @click.command( "create", help="Create a template settings file. If one exists at the given path, " - "its overwritten. Please enter a directory, no filename", + "its overwritten. Please enter a directory, not a filename", ) @click.argument("path", default="~/.finn/") def config_create(path: str) -> None: diff --git a/src/finn/transformation/fpgadataflow/ip_cache.py b/src/finn/transformation/fpgadataflow/ip_cache.py new file mode 100644 index 0000000000..4278017be6 --- /dev/null +++ b/src/finn/transformation/fpgadataflow/ip_cache.py @@ -0,0 +1,642 @@ +"""Manage IP caching for FINN.""" + +from __future__ import annotations + +import hashlib +import json +import numpy as np +import os +import shlex +import shutil +import subprocess +import sys +from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.base import Transformation +from qonnx.util.basic import get_num_default_workers +from typing import TYPE_CHECKING, Any, Callable, Final, cast + +from finn.custom_op.fpgadataflow.attention import ScaledDotProductAttention +from finn.custom_op.fpgadataflow.channelwise_op import ChannelwiseOp +from finn.custom_op.fpgadataflow.elementwise_binary import ElementwiseBinaryOperation +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp +from finn.custom_op.fpgadataflow.lookup import Lookup +from finn.custom_op.fpgadataflow.matrixvectoractivation import MVAU +from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend +from finn.custom_op.fpgadataflow.thresholding import Thresholding +from finn.custom_op.fpgadataflow.vectorvectoractivation import VVAU +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.util.basic import make_build_dir +from finn.util.deps import get_cache_path, get_deps_path +from finn.util.exception import FINNConfigurationError, FINNInternalError +from finn.util.fpgadataflow import is_hls_node, is_rtl_node +from finn.util.logging import log + +if TYPE_CHECKING: + from onnx import NodeProto + from qonnx.core.modelwrapper import ModelWrapper + + +# UTILITY FUNCTIONS +def _ndarray_to_bytes(tensor: Any) -> bytes: + cont = np.ascontiguousarray(tensor) + assert type(tensor) is np.ndarray + return cont.tobytes() + str(tensor.shape).encode("UTF-8") + + +def _attribute_path_exists(name: str, op: HWCustomOp) -> bool: + """Check that the node attribute path exists. + If the node attribute cannot be loaded, return False.""" # noqa + try: + data = op.get_nodeattr(name) + if data is None or data == "": + return False + return Path(cast(str, data)).exists() + except Exception: + return False + + +def _check_path_lengths_okay( + pc_name_max: int, pc_path_max: int, hashed_key: str, target_dir: Path +) -> bool: + """Check if we follow the path length limits. If not return False, otherwise True.""" + if len(hashed_key) > pc_name_max: + log.error( + f"Cannot cache an IP: The hash hex representation " + f"is too long to be allowed as a filename on your " + f"system (best effort detected limit: " + f"{pc_name_max}). Skipping caching." + ) + return False + path_bytes = len(str(target_dir.absolute()).encode("UTF-8")) + if path_bytes > pc_path_max: + log.error( + f"Cannot cache an IP: the generated path length of " + f"the cache location is not allowed on your system! " + f"The best effort detected limit is: " + f"{pc_path_max} bytes, the path length is " + f"{path_bytes} bytes. Skipping caching." + ) + return False + return True + + +CACHE_IP_DEFINITIONS: dict[type, dict[str, list[str]]] = {} +"""Contains all node attributes that a custom operator needs to be characterized. +Filled by the cache_ip decorator. If the field "use" is defined, these attributes are +used to hash the op. +>>> CACHE_IP_DEFINITIONS[my_operator]["use"] = [...] + +However if "ignore" is used, every attribute _except_ those listed are used. +>>> CACHE_IP_DEFINITIONS[my_operator]["ignore"] = [...] +""" + + +def cache_ip(attributes: list[str] | None = None) -> Callable[[type], type]: + """Decorate the given custom operator to be cacheable. + + Args: + attributes: List of the key names of all node attributes needed to + identify IP cores. + """ + global CACHE_IP_DEFINITIONS + + def wrapper(op_cls: type) -> type: + assert issubclass( + op_cls, HWCustomOp + ), f"Can only cache HWCustomOp instances, but {op_cls.__name__} is not a HWCustomOP!" + if op_cls not in CACHE_IP_DEFINITIONS.keys(): + CACHE_IP_DEFINITIONS[op_cls] = {} + else: + # Already marked + return op_cls + if attributes is not None: + CACHE_IP_DEFINITIONS[op_cls]["use"] = attributes + else: + # List of fields that don't define the IP core itself, + # and can thus be ignored when hashing + ignore_fields = [ + "code_gen_dir_ipgen", + "gen_top_module", + "ip_vlnv", + "ipgen_path", + "ip_path", + "cycles_rtlsim", + "cycles_estimate", + "res_estimate", + "res_synth", + "rtlsim_so", + "executable_path", + "res_hls", + "code_gen_dir_cppsim", + ] + CACHE_IP_DEFINITIONS[op_cls]["ignore"] = ignore_fields + return op_cls + + return wrapper + + +class IPCache: + """Manage IP caching. + + Public methods that are relevant for the caches usage: + - `model = cache.apply(model)`: Fetch cached IPs and apply them to the model, + returning the new model + - `cache.update(model)`: Update the cache by adding synthesized IPs that are not + yet cached into the cache. + - `cache.get_key(op, model)`: Get the key (string) of the given custom op + - `cache.get_hash_hex(key)`: Get the hex representation of the hash of the given key. + - `cache.get_num_cached_ips(model)`: Get the number of cached IPs in the given model. + """ + + allowed_hashfuncs: Final[list[str]] = ["sha256", "sha512", "blake2s", "blake2b"] + + def __init__( + self, + cache_dir: Path, + hashfunc: str, + hls_clk_period: float, + cache_hls_clk: bool, + fpgapart: str, + cache_fpgapart: bool, + ) -> None: + """Construct a new IPCache object. + + Args: + cache_dir: The path of the cache directory. + hashfunc: The name of the hash function to be used. + hls_clk_period: HLS clock period in ns. + cache_hls_clk: Use the HLS clock as part of the key. + fpgapart: FPGA-part used for HLSSynth and PrepareIP. + cache_fpgapart: Use the fpgapart as part of the key. + """ + self.cache_dir = cache_dir + self.cache_hls_clk = cache_hls_clk + self.cache_fpgapart = cache_fpgapart + + # Used to check validity of cache directory names + if sys.platform != "win32": + self.max_hash_len = os.pathconf("/", "PC_NAME_MAX") + self.max_path_len = os.pathconf("/", "PC_PATH_MAX") + else: + # TODO: Implement filesystem checks + # 256 seems to be the default max path length under windows + self.max_hash_len = 256 + self.max_path_len = 256 + + if not self.cache_dir.exists(): + self.cache_dir.mkdir() + log.info(f"Opened cache handler. Cache directory: {self.cache_dir}") + if hashfunc not in dir(hashlib): + raise FINNConfigurationError(f"There is no hash function with the name {hashfunc}!") + if hashfunc not in self.allowed_hashfuncs: + raise FINNConfigurationError( + f"Hash function {hashfunc} not available for caching. " + f"Choose one of: {self.allowed_hashfuncs}" + ) + + self.hashfunc_name = hashfunc + self.hasher: Callable = getattr(hashlib, hashfunc) + + # Prepare some always needed values + # FINN Commit + self.finn_commit = subprocess.run( + shlex.split("git rev-parse HEAD"), + text=True, + capture_output=True, + cwd=Path(__file__).parent, + ).stdout.strip() + self.finn_commit_time = subprocess.run( + shlex.split("git show --quiet --format=%ai"), + text=True, + capture_output=True, + cwd=Path(__file__).parent, + ).stdout.strip() + log.info(f"FINN Commit reads: {self.finn_commit} (authored at: {self.finn_commit_time})") + + # FINN HLSLIB Commit + self.hlslib_commit = subprocess.run( + shlex.split("git rev-parse HEAD"), + text=True, + capture_output=True, + cwd=get_deps_path() / "finn-hlslib", + ).stdout.strip() + self.hlslib_commit_time = subprocess.run( + shlex.split("git show --quiet --format=%ai"), + text=True, + capture_output=True, + cwd=get_deps_path() / "finn-hlslib", + ).stdout.strip() + log.info( + f"HLSLIB Commit reads: {self.hlslib_commit} " + f"(authored at: {self.hlslib_commit_time})" + ) + + # HLS Clk and device + self.clk = hls_clk_period + self.fpgapart = fpgapart + + def _get_key_part_attributes(self, op: HWCustomOp) -> str: + """Return the part of the key that contains attributes and their values.""" + key_part = "" + typ = type(op) + attrs: list[str] = [] + if "use" in CACHE_IP_DEFINITIONS[typ].keys(): + attrs = CACHE_IP_DEFINITIONS[typ]["use"] + elif "ignore" in CACHE_IP_DEFINITIONS[typ].keys(): + attrs = [ + k + for k in op.get_nodeattr_types().keys() + if k not in CACHE_IP_DEFINITIONS[typ]["ignore"] + ] + else: + raise FINNInternalError("This codepath should not be reachable!") + for attr in attrs: + data = None + try: + data = op.get_nodeattr(attr) + except Exception: + continue + try: + data = str(data) + except Exception as e: + raise FINNInternalError( + f"Unable to create string-representation for node " + f"attribute {attr} of custom op {op.onnx_node.name} of " + f"type {type(op)}." + ) from e + key_part += f"{attr}:{data}\n" + return key_part + + def _get_key_part_parameter(self, op: HWCustomOp, model: ModelWrapper) -> str: + """Get the key part defined by the op parameters. + + If, for example, weights, are embedded into the operators, they need to + be part of the hashed key as well. + """ + if isinstance(op, (MVAU, VVAU)): + mem_mode = None + try: + mem_mode = op.get_nodeattr("mem_mode") + except Exception as e: + raise FINNInternalError( + f"Cannot cache {op.onnx_node.name} because op is of " + f"type MVAU but has no mem_mode set!" + ) from e + if mem_mode in ["internal_embedded", "internal_decoupled"]: + weightbytes = _ndarray_to_bytes(model.get_initializer(op.onnx_node.input[1])) + try: + threshbytes = _ndarray_to_bytes(model.get_initializer(op.onnx_node.input[2])) + except IndexError: + # No thresholds + threshbytes = b"" + array_hash = self.hasher(weightbytes + threshbytes).hexdigest() + return f"param_hash:{array_hash}\n" + elif isinstance(op, (Thresholding, ChannelwiseOp, Lookup)): + parambytes = _ndarray_to_bytes(model.get_initializer(op.onnx_node.input[1])) + array_hash = self.hasher(parambytes).hexdigest() + return f"param_hash:{array_hash}\n" + elif isinstance(op, (ElementwiseBinaryOperation,)): + parambytes0 = _ndarray_to_bytes(model.get_initializer(op.onnx_node.input[0])) + parambytes1 = _ndarray_to_bytes(model.get_initializer(op.onnx_node.input[1])) + array_hash = self.hasher(parambytes0 + parambytes1).hexdigest() + return f"param_hash:{array_hash}\n" + elif isinstance(op, ScaledDotProductAttention): + key_part = "" + if op.get_nodeattr("ActQKMatMul") == "thresholds": + thresholds = model.get_initializer( + op.get_input_name_by_name("thresholds_qk_matmul") + ) + hashed = self.hasher(_ndarray_to_bytes(thresholds)).hexdigest() + key_part += f"thresholds_qk_matmul:{hashed}\n" + if op.get_nodeattr("ActASoftmax") == "thresholds": + thresholds = model.get_initializer( + op.get_input_name_by_name("thresholds_a_softmax") + ) + hashed = self.hasher(_ndarray_to_bytes(thresholds)).hexdigest() + key_part += f"thresholds_a_softmax:{hashed}\n" + if op.get_nodeattr("ActAVMatMul") == "thresholds": + thresholds = model.get_initializer( + op.get_input_name_by_name("thresholds_av_matmul") + ) + hashed = self.hasher(_ndarray_to_bytes(thresholds)).hexdigest() + key_part += f"thresholds_av_matmul:{hashed}\n" + if op.get_nodeattr("mask_mode") == "const": + mask = model.get_initializer(op.get_input_name_by_name("M")) + hashed = self.hasher(_ndarray_to_bytes(mask)).hexdigest() + key_part += f"M:{hashed}\n" + return key_part + return "" + + def get_key(self, op: HWCustomOp, model: ModelWrapper) -> str: + """Return the key that can be hashed, for the given custom op. + + These parts are used to build the key which is then hashed for the cache: + - FINN commit + - FINN-HLSLIB commit + - Custom Op type + - (Optional) HLS clock + - (Optional) HLS Synthesis FPGA-part + - All node attributes that define a unique instance of the operator (set by @cache_ip(...)) + - All external parameters for ops that have these (for example MVAU) + - These are hashed themselves for brevity, otherwise the key might be megabytes of data + + **IMPORTANT**: Keep in mind that changes in this function will require caching everything + again. + + Returns: + str: The human-readable key. Can be used to generate the caching + hash and the metadata file packed with the cached data. + """ + global CACHE_IP_DEFINITIONS + if type(op) not in CACHE_IP_DEFINITIONS.keys(): + log.error( + f"Tried getting the key for a non-cacheable custom operator ({type(op).__name__}). " + "Did you perhaps forget to register the op for caching via " + "@cache_ip(...)?" + ) + key = f"FINN: {self.finn_commit}\nHLSLIB: {self.hlslib_commit}\n" + key += "type:" + type(op).__name__ + "\n" + if self.cache_hls_clk: + key += f"hls_clk_period_ns:{self.clk}\n" + if self.cache_fpgapart: + key += f"fpgapart:{self.fpgapart}\n" + key += self._get_key_part_attributes(op) + "\n" + key += self._get_key_part_parameter(op, model) + return key + + def get_hash_hex(self, key: str) -> str: + """Return the hex repr of the hash of the given key.""" + return self.hasher(key.encode("UTF-8")).hexdigest() + + def _create_key_file(self, key: str, path: Path) -> None: + """Write the given key data into a file at the given path.""" + with path.open("w+") as f: + f.write(f"Hashed using {self.hashfunc_name}.\n") + f.write(f"Final overall hashed key: {self.get_hash_hex(key)}") + f.write(f"FINN Commit Date: {self.finn_commit_time}\n") + f.write(f"FINN HLSLIB Commit Date: {self.hlslib_commit_time}\n") + f.write("Key:\n------------------------\n") + f.write(key) + + def _dump_nodeattrs( + self, op: HWCustomOp, path: Path, additional_attributes: list[str] | None = None + ) -> None: + """Dump the custom ops node attributes at the given path as a JSON. + + If a node attribute cannot be accessed, it is silently ignored. + + Args: + op: The HWCustom op of which the node attributes are the target + path: Where to dump the node attributes + additional_attributes: A list of additional attribute keys that + should be included in the dump. + """ + if additional_attributes is None: + additional_attributes = [] + required = {"ip_vlnv", "gen_top_module", *additional_attributes} + d = {} + for name in op.get_nodeattr_types().keys(): + if name in required: + try: + d[name] = op.get_nodeattr(name) + except Exception: + continue + with path.open("w+") as f: + json.dump(d, f) + + @staticmethod + def _replace_modulename(directory: Path, old: str, new: str) -> None: + """Recursively walk the directory and change all file/directory names, as well + as contents in the files from the old string to the new string. + """ # noqa + if not directory.is_dir(): + raise FINNInternalError(f"Cannot replace module names in non-directory: {directory}") + + # Walk all paths recursively + for obj in directory.rglob("*"): + obj: Path + + # Replace file/directory names + if old in obj.name: + new_path = obj.with_name(obj.name.replace(old, new)) + obj.rename(new_path) + obj = new_path + + # Replace contents in files + if obj.is_file(): + try: + text = obj.read_text() + except UnicodeDecodeError: + # We might accidentally read a binary file + # In that case just move on + continue + obj.write_text(text.replace(old, new)) + + @staticmethod + def _prepare_from_cached_ip( + op: HWCustomOp, hashed_key: str, make_copy: bool, cache_dir: Path + ) -> None: + """Prepare the given custom op for usage of the given cached IP. + + We have to set some node attributes normally set by HLSSynth and PrepareIP. This needs to + be done to use the cached IP. + + Args: + op: The operator of which the node attributes we have to modify. + hashed_key: The hash hex repr of the key for this op. Used to find the cached IP. + make_copy: If True, first makes a copy of the cached IP in the current FINN_BUILD_DIR + and sets the path towards this copy instead of the cached original. + cache_dir: FINN_IP_CACHE directory, as passed from the calling IPCache instance. + """ + log.info(f"Preparing {op.onnx_node.name} from cached IP (hashed key: {hashed_key[:10]}...)") + ip_dir = cache_dir / hashed_key + saved_nodeattrs = {} + + # Check if the cached IP really exists + if not ip_dir.exists(): + raise FINNInternalError( + f"Cannot use hashed key {hashed_key}: Cache dir {ip_dir} does not exist!" + ) + + # Read node attributes from saved directory + with (ip_dir / "nodeattrs.json").open("r") as f: + saved_nodeattrs = json.load(f) + + # If needed make copy of the cached dir + if make_copy: + copied_dir = Path(make_build_dir(prefix=f"cached_code_gen_ipgen_{op.onnx_node.name}")) + shutil.copytree(ip_dir, copied_dir, dirs_exist_ok=True) + ip_dir = copied_dir + + # Set node attributes correctly to point to cached directory + op.set_nodeattr("code_gen_dir_ipgen", str(ip_dir)) + if issubclass(type(op), RTLBackend): + # Rename module in filenames and contents from the cached name to applied node name + old_module_name = saved_nodeattrs["gen_top_module"] + new_module_name = op.get_verilog_top_module_name() + if old_module_name != new_module_name: + log.debug( + f"{op.onnx_node.name}: Replacing cached module name: {old_module_name} " + f"with applied module name: {new_module_name}" + ) + IPCache._replace_modulename(ip_dir, old_module_name, new_module_name) + op.set_nodeattr("ip_path", str(ip_dir)) + op.set_nodeattr("ipgen_path", str(ip_dir)) + op.set_nodeattr("gen_top_module", new_module_name) + + elif issubclass(type(op), HLSBackend): + op.set_nodeattr("ip_vlnv", saved_nodeattrs["ip_vlnv"]) + op.set_nodeattr( + "ip_path", str(ip_dir / f"project_{op.onnx_node.name}" / "sol1" / "impl" / "ip") + ) + op.set_nodeattr("ipgen_path", str(ip_dir / f"project_{op.onnx_node.name}")) + + def _get_node_data( + self, node: NodeProto, model: ModelWrapper + ) -> tuple[HWCustomOp, str, str, Path]: + """Return the op, key, hashed key, cache dir path for a given node.""" + op = getCustomOp(node) + key = self.get_key(op, model) + hashed_key = self.get_hash_hex(key) + return op, key, hashed_key, self.cache_dir / hashed_key + + def _is_op_synthesized(self, op: HWCustomOp) -> bool: + """Return whether the given op is synthesized. This is derived from the existence and + validity of the paths in code_gen_dir_ipgen, ipgen_path and ip_path.""" # noqa + return ( + _attribute_path_exists("code_gen_dir_ipgen", op) + and _attribute_path_exists("ip_path", op) + and _attribute_path_exists("ipgen_path", op) + ) + + def get_num_cached_ips(self, model: ModelWrapper) -> int: + """Return the number of cached IPs in the model.""" + count = 0 + for node in model.graph.node: + _, _, _, cache_dir = self._get_node_data(node, model) + if cache_dir.exists(): + count += 1 + return count + + def apply(self, model: ModelWrapper) -> ModelWrapper: + """Apply all IPs that were cached to the model and return it.""" + futures: list[Future] = [] + with ThreadPoolExecutor(max_workers=get_num_default_workers()) as pool: + for node in model.graph.node: + op, _, hashed_key, op_cache_dir = self._get_node_data(node, model) + if op_cache_dir.exists(): + futures.append( + pool.submit( + IPCache._prepare_from_cached_ip, + op=op, + hashed_key=hashed_key, + make_copy=True, + cache_dir=self.cache_dir, + ) + ) + pool.shutdown(wait=True) + + # Raise exceptions from threads if there were any + for future in futures: + _ = future.result() + return model + + def update(self, model: ModelWrapper) -> None: + """Check a model for generated IPs that were not yet cached, and cache them. + + Requires HLSSynthIP() to be run before. + """ + total_cached = 0 + for node in model.graph.node: + op, key, hashed_key, target_dir = self._get_node_data(node, model) + if not _check_path_lengths_okay( + self.max_hash_len, self.max_path_len, hashed_key, target_dir + ): + return + if not (is_hls_node(node) or is_rtl_node(node)): + log.warning(f"Cannot cache node {node.name}. Node is not a HW node!") + continue + if not target_dir.exists(): + if not self._is_op_synthesized(op): + log.warning( + f"{node.name} hasn't been synthesized yet and can't be cached " + f"(one of code_gen_dir_ipgen, ip_path, ipgen_path is missing or " + f"invalid!). Hash after synthesis will be: {hashed_key}" + ) + continue + code_gen_dir = Path(cast(str, op.get_nodeattr("code_gen_dir_ipgen"))) + if not code_gen_dir.exists(): + log.warning( + f"Could not cache {node.name}: code_gen_dir_ipgen not set. " + f"Did HLSSynthIP() fail/was not run?" + ) + shutil.copytree(code_gen_dir, target_dir, dirs_exist_ok=True) + self._create_key_file(key, target_dir / "key.txt") + self._dump_nodeattrs(op, target_dir / "nodeattrs.json") + log.info(f"Cached node {node.name}. Cached at: {target_dir} from {code_gen_dir}!") + total_cached += 1 + log.info(f"Cached a total of {total_cached} new ops.") + + +class CachedIPGen(Transformation): + """(PrepareIP and) HLSSynth but cached.""" + + def __init__( + self, + hash_function: str, + include_prepare_ip: bool, + clk: float, + cache_clock: bool, + fpgapart: str, + cache_fpgapart: bool, + ) -> None: + """(PrepareIP and) HLSSynth but cached. + + Args: + hash_function: Hashfunction to use. + include_prepare_ip: If True, also run PrepareIP before synthesis. + fpgapart: Required if PrepareIP is being run. + cache_fpgapart: Whether or not to use the fpgapart for the cache ky + clk: Required if PrepareIP is being run. + cache_clock: Whether or not to use the clock for the cache key + """ + super().__init__() + self.hashfunc = hash_function + self.prepareip = include_prepare_ip + self.part = fpgapart + self.cache_part = cache_fpgapart + self.clk = clk + self.cache_clock = cache_clock + + def apply(self, model: ModelWrapper) -> tuple[ModelWrapper, bool]: + """Apply cached HLS Synthesis (and PrepareIP).""" + cache = IPCache( + cache_dir=get_cache_path(), + hashfunc=self.hashfunc, + hls_clk_period=self.clk, + cache_hls_clk=self.cache_clock, + fpgapart=self.part, + cache_fpgapart=self.cache_part, + ) + log.info( + f"Applying cache to {cache.get_num_cached_ips(model)} " + f"/ {len(model.graph.node)} nodes!" + ) + model = cache.apply(model) + if self.prepareip: + if self.part is None or self.clk is None: + raise FINNInternalError( + "Cannot run PrepareIP in CachedIPGen without fpgapart and clk being passed!" + ) + log.info("Running PrepareIP for uncached IPs...") + model = model.transform(PrepareIP(self.part, self.clk)) + cache.update(model) + log.info("Running synthesis for uncached IPs...") + model = model.transform(HLSSynthIP()) + log.info("Updating cache with newly generated IPs...") + cache.update(model) + return model, False diff --git a/src/finn/util/deps.py b/src/finn/util/deps.py index 8728481e75..5512b91f15 100644 --- a/src/finn/util/deps.py +++ b/src/finn/util/deps.py @@ -1,6 +1,8 @@ import os from pathlib import Path +from finn.util.exception import FINNInternalError + def get_deps_path() -> Path: """Get the dependency path from the environment variable. @@ -8,3 +10,14 @@ def get_deps_path() -> Path: if "FINN_DEPS" not in os.environ.keys(): return Path.home() / ".finn" / "deps" return Path(os.environ["FINN_DEPS"]) + + +def get_cache_path() -> Path: + """Return the path to the cache.""" + if "FINN_IP_CACHE" not in os.environ.keys(): + raise FINNInternalError( + "FINN_IP_CACHE environment variable not found! This may be a " + "bug, since the setup (run_finn.py) should always set this " + "variable!" + ) + return Path(os.environ["FINN_IP_CACHE"]) diff --git a/tests/infrastructure/test_ip_cache.py b/tests/infrastructure/test_ip_cache.py new file mode 100644 index 0000000000..72b7cd1822 --- /dev/null +++ b/tests/infrastructure/test_ip_cache.py @@ -0,0 +1,225 @@ +"""Test that the IP cache is working correctly. (No false positives, no collisions, speed, etc.).""" +from __future__ import annotations + +import pytest + +import numpy as np +import os +import time +from copy import deepcopy +from pathlib import Path +from qonnx.core.datatype import DataType +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames +from qonnx.util.basic import gen_finn_dt_tensor +from typing import TYPE_CHECKING, Literal, cast + +from finn.custom_op.fpgadataflow.hls.matrixvectoractivation_hls import MVAU_hls +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend +from finn.custom_op.fpgadataflow.rtl.matrixvectoractivation_rtl import MVAU_rtl +from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend +from finn.transformation.fpgadataflow.ip_cache import CachedIPGen, IPCache +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers +from finn.util.basic import alveo_part_map +from finn.util.deps import get_cache_path +from tests.fpgadataflow.test_fpgadataflow_mvau import make_single_fclayer_modelwrapper + +if TYPE_CHECKING: + from qonnx.core.modelwrapper import ModelWrapper + + from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + + +def mvau_create_model( + fpgapart: str, mode: Literal["hls", "rtl"] +) -> tuple[ModelWrapper, np.ndarray, np.ndarray]: + """Create and sanity check a model for testing MVAU caching. + + Returns: + ModelWrapper, NDArray, NDArray: Model, weights, thresholds. + """ + # TODO: Fix gen_finn_dt_tensor issue in our QONNX (same values + # for subsequent calls of the function) + W = gen_finn_dt_tensor(DataType["INT4"], (10, 10), seed=1) + T = gen_finn_dt_tensor(DataType["INT4"], (10, 10), seed=1) + + # Creating the model + model = make_single_fclayer_modelwrapper( + W, 1, 1, DataType["INT4"], DataType["INT4"], DataType["INT4"], T, DataType["INT4"] + ) + + op: HWCustomOp = getCustomOp(model.graph.node[0]) + op.set_nodeattr("preferred_impl_style", mode) + if mode == "rtl": + # Required to set MVAU implementation to rtl + op.set_nodeattr("noActivation", 1) + op.set_nodeattr("binaryXnorMode", 0) + + model = model.transform(SpecializeLayers(fpgapart)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + + # Some sanity checks + assert model.graph.node[0].op_type == "MVAU_" + mode + assert getCustomOp(model.graph.node[0]).get_nodeattr("mem_mode") in [ + "internal_decoupled", + "internal_embedded", + ] + return model, W, T + + +def mvau_specific_asserts( + model: ModelWrapper, + original_op: HWCustomOp, + original_cache: IPCache, + original_key: str, + W: np.ndarray, + T: np.ndarray, +) -> None: + """Run MVAU specific asserts to validate caching.""" + for attribute in [ + "resType", + "MW", + "MH", + "SIMD", + "PE", + "inputDataType", + "weightDataType", + "outputDataType", + ]: + original_value = original_op.get_nodeattr(attribute) + if attribute in ["MW", "MH", "SIMD", "PE"]: + original_op.set_nodeattr(attribute, original_value + 1) # type: ignore + elif attribute == "resType": + assert original_value == "auto" + original_op.set_nodeattr(attribute, "dsp") + else: + original_op.set_nodeattr(attribute, "UINT6") + assert original_cache.get_key(original_op, model) != original_key + original_op.set_nodeattr(attribute, original_value) + + # Check that the hash changes with the parameters + # Weights + new_W = gen_finn_dt_tensor(DataType["UINT4"], (10, 10), seed=2) + assert not np.array_equal(W, new_W) + weight_init = model.graph.node[0].input[1] + model.set_initializer(weight_init, new_W) + new_key = original_cache.get_key(original_op, model) + assert original_key != new_key + model.set_initializer(weight_init, W) + + # Thresholds + new_T = gen_finn_dt_tensor(DataType["UINT4"], (10, 10), seed=2) + assert not np.array_equal(T, new_T) + thresh_init = model.graph.node[0].input[2] + model.set_initializer(thresh_init, new_T) + new_key = original_cache.get_key(original_op, model) + assert original_key != new_key + model.set_initializer(thresh_init, T) + + +def get_first_op(model: ModelWrapper) -> HWCustomOp: + """Return the op of the first node in the model.""" + return getCustomOp(model.graph.node[0]) + + +@pytest.mark.parametrize("op_type", [MVAU_hls, MVAU_rtl]) +@pytest.mark.parametrize("hashfunc", ["sha256"]) +@pytest.mark.parametrize("fpgapart", [alveo_part_map["U280"]]) +@pytest.mark.parametrize("hls_clk", [2.5]) +def test_ip_hash_key(op_type: type, hashfunc: str, fpgapart: str, hls_clk: float) -> None: + """Test IP Caching. + + To do so, we create models that we then run the cache on. We check, that for + changes in any attribute, external parameter and clock the hash generated changes as well. + We also check, that the generated IP is at the correct path, with all meta-information, + and that subsequent synthesis actually use the cached IP by measuring the time needed + to re-run synthesis on a fresh copy of the original model. + """ + os.environ["FINN_IP_CACHE"] = os.environ["FINN_BUILD_DIR"] + + # Create the model + model: ModelWrapper + if op_type is MVAU_hls: + model, W, T = mvau_create_model(fpgapart, mode="hls") + elif op_type is MVAU_rtl: + model, W, T = mvau_create_model(fpgapart, mode="rtl") + else: + raise AssertionError(f"Cache test for op {op_type.__name__} not yet implemented!") + + # Save a copy of the unsynthesized model for later + unsynth_model = deepcopy(model) + + # Run the cache transformation + model = model.transform( + CachedIPGen( + hash_function=hashfunc, + include_prepare_ip=True, + cache_clock=True, + clk=hls_clk, + cache_fpgapart=True, + fpgapart=fpgapart, + ) + ) + cache = IPCache( + cache_dir=get_cache_path(), + hashfunc=hashfunc, + hls_clk_period=hls_clk, + cache_hls_clk=True, + fpgapart=fpgapart, + cache_fpgapart=True, + ) + original_op = get_first_op(model) + original_key = cache.get_key(original_op, model) + + # Check that the hash changes with the attributes + if op_type in [MVAU_hls, MVAU_rtl]: + mvau_specific_asserts(model, original_op, cache, original_key, W, T) + else: + raise AssertionError(f"{op_type.__name__} specific cache test asserts not yet implemented!") + + # Check that the IP was cached at the correct path + path = cache.cache_dir / cache.get_hash_hex(original_key) + assert path.exists() + assert (path / "nodeattrs.json").exists() + assert (path / "key.txt").exists() + with (path / "key.txt").open("r") as f: + data = f.read() + assert f"type:{op_type.__name__}" in data + assert f"Hashed using {hashfunc}" in data + assert original_key in data + + # Check that a different HLS clk generates a different key + other_clk_cache = IPCache(get_cache_path(), hashfunc, hls_clk + 1.0, True, fpgapart, True) + assert cache.get_key(original_op, model) != other_clk_cache.get_key(original_op, model) + + # Check speed of the second call (should be much faster) + start: float = time.time() + unsynth_model = unsynth_model.transform( + CachedIPGen(hashfunc, True, hls_clk, True, fpgapart, True) + ) + ms_elapsed = time.time() - start + + # Time in seconds that the cached transform may take. + # 10s should be enough, even on slow systems, but if it is clear that + # there isn't a bug, this can be adjusted if it leads to failing + # CI runs. + CACHE_TIME_ALLOWED = 10 + assert ms_elapsed <= 1000 * CACHE_TIME_ALLOWED + + # Check that the cached and re-used IP does exist + first_op = get_first_op(unsynth_model) + codegen_path = Path(cast(str, first_op.get_nodeattr("code_gen_dir_ipgen"))) + if issubclass(op_type, HLSBackend): + expected_ip_path = ( + codegen_path / f"project_{first_op.onnx_node.name}" / "sol1" / "impl" / "ip" + ) + assert expected_ip_path.exists() + elif issubclass(op_type, RTLBackend): + for f in (cache.cache_dir / cache.get_hash_hex(cache.get_key(first_op, model))).iterdir(): + assert (codegen_path / f).exists() + else: + raise AssertionError( + f"{op_type.__name__} doesnt have either an HLS or RTL backend. " + f"Only test subclasses that can actually be cached!" + )