Skip to content

Commit

Permalink
Merge pull request #92 from Xilinx/update/finn_examples_driver
Browse files Browse the repository at this point in the history
Add argument to provide path to bitfile
  • Loading branch information
auphelia authored May 3, 2024
2 parents c670721 + ad5cf2c commit bfb729a
Showing 1 changed file with 47 additions and 47 deletions.
94 changes: 47 additions & 47 deletions finn_examples/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,24 +190,27 @@ def get_edge_or_pcie():
raise OSError("Platform is not supported.")


def find_bitfile(model_name, target_platform):
bitfile_exts = {"edge": "bit", "pcie": "xclbin"}
bitfile_ext = bitfile_exts[get_edge_or_pcie()]
bitfile_name = "%s.%s" % (model_name, bitfile_ext)
bitfile_candidates = [
pk.resource_filename("finn_examples", "bitfiles/%s/%s" % (target_platform, bitfile_name)),
pk.resource_filename(
"finn_examples",
"bitfiles/bitfiles.zip.d/%s/%s" % (target_platform, bitfile_name),
),
]
for candidate in bitfile_candidates:
if os.path.isfile(candidate):
return candidate
raise Exception(
"Bitfile for model = %s target platform = %s not found. Looked in: %s"
% (model_name, target_platform, str(bitfile_candidates))
)
def find_bitfile(model_name, target_platform, bitfile_path):
if bitfile_path is not None:
return bitfile_path
else:
bitfile_exts = {"edge": "bit", "pcie": "xclbin"}
bitfile_ext = bitfile_exts[get_edge_or_pcie()]
bitfile_name = "%s.%s" % (model_name, bitfile_ext)
bitfile_candidates = [
pk.resource_filename("finn_examples", "bitfiles/%s/%s" % (target_platform, bitfile_name)),
pk.resource_filename(
"finn_examples",
"bitfiles/bitfiles.zip.d/%s/%s" % (target_platform, bitfile_name),
),
]
for candidate in bitfile_candidates:
if os.path.isfile(candidate):
return candidate
raise Exception(
"Bitfile for model = %s target platform = %s not found. Looked in: %s"
% (model_name, target_platform, str(bitfile_candidates))
)


def find_runtime_weights(model_name, target_platform):
Expand Down Expand Up @@ -266,75 +269,75 @@ def resolve_target_platform(target_platform):
return check_platform_is_valid(platform)


def kws_mlp(target_platform=None):
def kws_mlp(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "kwsmlp-w3a3"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _gscv2_mlp_io_shape_dict)


def tfc_w1a1_mnist(target_platform=None):
def tfc_w1a1_mnist(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "tfc-w1a1"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _mnist_fc_io_shape_dict)


def tfc_w1a2_mnist(target_platform=None):
def tfc_w1a2_mnist(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "tfc-w1a2"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _mnist_fc_io_shape_dict)


def tfc_w2a2_mnist(target_platform=None):
def tfc_w2a2_mnist(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "tfc-w2a2"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _mnist_fc_io_shape_dict)


def cnv_w1a1_cifar10(target_platform=None):
def cnv_w1a1_cifar10(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "cnv-w1a1"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _cifar10_cnv_io_shape_dict)


def cnv_w1a2_cifar10(target_platform=None):
def cnv_w1a2_cifar10(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "cnv-w1a2"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _cifar10_cnv_io_shape_dict)


def cnv_w2a2_cifar10(target_platform=None):
def cnv_w2a2_cifar10(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "cnv-w2a2"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _cifar10_cnv_io_shape_dict)


def bincop_cnv(target_platform=None):
def bincop_cnv(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "bincop-cnv"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _bincop_cnv_io_shape_dict)


def mobilenetv1_w4a4_imagenet(target_platform=None):
def mobilenetv1_w4a4_imagenet(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "mobilenetv1-w4a4"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
if target_platform in ["ZCU104"]:
runtime_weight_dir = find_runtime_weights(model_name, target_platform)
else:
Expand All @@ -350,11 +353,11 @@ def mobilenetv1_w4a4_imagenet(target_platform=None):
)


def resnet50_w1a2_imagenet(target_platform=None):
def resnet50_w1a2_imagenet(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "resnet50-w1a2"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
runtime_weight_dir = find_runtime_weights(model_name, target_platform)
return FINNExampleOverlay(
filename,
Expand All @@ -363,12 +366,11 @@ def resnet50_w1a2_imagenet(target_platform=None):
runtime_weight_dir=runtime_weight_dir,
)


def vgg10_w4a4_radioml(target_platform=None):
def vgg10_w4a4_radioml(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "radioml_w4a4_small_tidy"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
fclk_mhz = 250.0
return FINNExampleOverlay(
filename,
Expand All @@ -377,21 +379,19 @@ def vgg10_w4a4_radioml(target_platform=None):
fclk_mhz=fclk_mhz,
)


def mlp_w2a2_unsw_nb15(target_platform=None):
def mlp_w2a2_unsw_nb15(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "unsw_nb15-mlp-w2a2"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
fclk_mhz = 100.0
return FINNExampleOverlay(
filename, driver_mode, _unsw_nb15_mlp_io_shape_dict, fclk_mhz=fclk_mhz
)


def cnv_w1a1_gtsrb(target_platform=None):
def cnv_w1a1_gtsrb(target_platform=None, bitfile_path=None):
target_platform = resolve_target_platform(target_platform)
driver_mode = get_driver_mode()
model_name = "cnv-gtsrb-w1a1"
filename = find_bitfile(model_name, target_platform)
filename = find_bitfile(model_name, target_platform, bitfile_path)
return FINNExampleOverlay(filename, driver_mode, _gtsrb_cnv_io_shape_dict)

0 comments on commit bfb729a

Please sign in to comment.