From 7072f8f2ab441d8943c5606041c21849480f2605 Mon Sep 17 00:00:00 2001 From: mmrahorovic Date: Fri, 3 May 2024 10:54:38 +0100 Subject: [PATCH] [models]: add argument to provide path to bitfile --- finn_examples/models.py | 91 +++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/finn_examples/models.py b/finn_examples/models.py index c22e190..dfc7dcd 100644 --- a/finn_examples/models.py +++ b/finn_examples/models.py @@ -190,24 +190,27 @@ def get_edge_or_pcie(): raise OSError("Platform is not supported.") -def find_bitfile(model_name, target_platform): - bitfile_exts = {"edge": "bit", "pcie": "xclbin"} - bitfile_ext = bitfile_exts[get_edge_or_pcie()] - bitfile_name = "%s.%s" % (model_name, bitfile_ext) - bitfile_candidates = [ - pk.resource_filename("finn_examples", "bitfiles/%s/%s" % (target_platform, bitfile_name)), - pk.resource_filename( - "finn_examples", - "bitfiles/bitfiles.zip.d/%s/%s" % (target_platform, bitfile_name), - ), - ] - for candidate in bitfile_candidates: - if os.path.isfile(candidate): - return candidate - raise Exception( - "Bitfile for model = %s target platform = %s not found. Looked in: %s" - % (model_name, target_platform, str(bitfile_candidates)) - ) +def find_bitfile(model_name, target_platform, bitfile_path): + if bitfile_path is not None: + return bitfile_path + else: + bitfile_exts = {"edge": "bit", "pcie": "xclbin"} + bitfile_ext = bitfile_exts[get_edge_or_pcie()] + bitfile_name = "%s.%s" % (model_name, bitfile_ext) + bitfile_candidates = [ + pk.resource_filename("finn_examples", "bitfiles/%s/%s" % (target_platform, bitfile_name)), + pk.resource_filename( + "finn_examples", + "bitfiles/bitfiles.zip.d/%s/%s" % (target_platform, bitfile_name), + ), + ] + for candidate in bitfile_candidates: + if os.path.isfile(candidate): + return candidate + raise Exception( + "Bitfile for model = %s target platform = %s not found. Looked in: %s" + % (model_name, target_platform, str(bitfile_candidates)) + ) def find_runtime_weights(model_name, target_platform): @@ -266,75 +269,75 @@ def resolve_target_platform(target_platform): return check_platform_is_valid(platform) -def kws_mlp(target_platform=None): +def kws_mlp(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "kwsmlp-w3a3" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _gscv2_mlp_io_shape_dict) -def tfc_w1a1_mnist(target_platform=None): +def tfc_w1a1_mnist(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "tfc-w1a1" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _mnist_fc_io_shape_dict) -def tfc_w1a2_mnist(target_platform=None): +def tfc_w1a2_mnist(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "tfc-w1a2" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _mnist_fc_io_shape_dict) -def tfc_w2a2_mnist(target_platform=None): +def tfc_w2a2_mnist(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "tfc-w2a2" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _mnist_fc_io_shape_dict) -def cnv_w1a1_cifar10(target_platform=None): +def cnv_w1a1_cifar10(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "cnv-w1a1" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _cifar10_cnv_io_shape_dict) -def cnv_w1a2_cifar10(target_platform=None): +def cnv_w1a2_cifar10(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "cnv-w1a2" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _cifar10_cnv_io_shape_dict) -def cnv_w2a2_cifar10(target_platform=None): +def cnv_w2a2_cifar10(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "cnv-w2a2" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _cifar10_cnv_io_shape_dict) -def bincop_cnv(target_platform=None): +def bincop_cnv(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "bincop-cnv" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _bincop_cnv_io_shape_dict) -def mobilenetv1_w4a4_imagenet(target_platform=None): +def mobilenetv1_w4a4_imagenet(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "mobilenetv1-w4a4" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) if target_platform in ["ZCU104"]: runtime_weight_dir = find_runtime_weights(model_name, target_platform) else: @@ -350,11 +353,11 @@ def mobilenetv1_w4a4_imagenet(target_platform=None): ) -def resnet50_w1a2_imagenet(target_platform=None): +def resnet50_w1a2_imagenet(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "resnet50-w1a2" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) runtime_weight_dir = find_runtime_weights(model_name, target_platform) return FINNExampleOverlay( filename, @@ -363,11 +366,11 @@ def resnet50_w1a2_imagenet(target_platform=None): runtime_weight_dir=runtime_weight_dir, ) -def vgg10_w4a4_radioml(target_platform=None): +def vgg10_w4a4_radioml(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "radioml_w4a4_small_tidy" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) fclk_mhz = 250.0 return FINNExampleOverlay( filename, @@ -376,19 +379,19 @@ def vgg10_w4a4_radioml(target_platform=None): fclk_mhz=fclk_mhz, ) -def mlp_w2a2_unsw_nb15(target_platform=None): +def mlp_w2a2_unsw_nb15(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "unsw_nb15-mlp-w2a2" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) fclk_mhz = 100.0 return FINNExampleOverlay( filename, driver_mode, _unsw_nb15_mlp_io_shape_dict, fclk_mhz=fclk_mhz ) -def cnv_w1a1_gtsrb(target_platform=None): +def cnv_w1a1_gtsrb(target_platform=None, bitfile_path=None): target_platform = resolve_target_platform(target_platform) driver_mode = get_driver_mode() model_name = "cnv-gtsrb-w1a1" - filename = find_bitfile(model_name, target_platform) + filename = find_bitfile(model_name, target_platform, bitfile_path) return FINNExampleOverlay(filename, driver_mode, _gtsrb_cnv_io_shape_dict) \ No newline at end of file