Skip to content

Commit

Permalink
Only try to check the executable running if the platform downloaded i…
Browse files Browse the repository at this point in the history
…s the same we're running in.
  • Loading branch information
fabioz committed Aug 29, 2024
1 parent 562755f commit e8a28b8
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions sema4ai/src/sema4ai_code/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def get_tool_version(tool: Tool, location: str) -> LaunchActionResult:
return launch((location,) + version_command)


def verify_tool_downloaded_ok(tool: Tool, location: str, force: bool) -> bool:
def verify_tool_downloaded_ok(
tool: Tool, location: str, force: bool, make_run_check: bool
) -> bool:
if location in _checked_downloaded_tools and not force:
if os.path.isfile(location):
return True # Already checked: just do simpler check.
Expand All @@ -76,14 +78,17 @@ def verify_tool_downloaded_ok(tool: Tool, location: str, force: bool) -> bool:

# Actually execute it to make sure it works (in windows right after downloading
# it may not be ready, so, retry a few times).
times = 5
timeout = 1
for _ in range(times):
version_result = get_tool_version(tool, location)
if version_result.success:
_checked_downloaded_tools.add(location)
return True
time.sleep(timeout / times)
if not make_run_check:
_checked_downloaded_tools.add(location)
else:
times = 5
timeout = 1
for _ in range(times):
version_result = get_tool_version(tool, location)
if version_result.success:
_checked_downloaded_tools.add(location)
return True
time.sleep(timeout / times)

log.info(f"Tool {location} failed to execute. Details: {version_result.message}")

Expand Down Expand Up @@ -127,7 +132,9 @@ def download_tool(
sys_platform = sys.platform

if not force:
if verify_tool_downloaded_ok(tool, location, force=force):
if verify_tool_downloaded_ok(
tool, location, force=force, make_run_check=sys_platform == sys.platform
):
return

tool_info = get_tool_info(tool)
Expand All @@ -136,7 +143,9 @@ def download_tool(
# If other call was already in progress, we need to check it again,
# as to not overwrite it when force was equal to False.
if not force:
if verify_tool_downloaded_ok(tool, location, force=force):
if verify_tool_downloaded_ok(
tool, location, force=force, make_run_check=sys_platform == sys.platform
):
return

if endpoint is not None:
Expand Down Expand Up @@ -164,7 +173,9 @@ def download_tool(
progress_reporter=progress_reporter,
)

if not verify_tool_downloaded_ok(tool, location, force=True):
if not verify_tool_downloaded_ok(
tool, location, force=True, make_run_check=sys_platform == sys.platform
):
raise Exception(
f"After downloading {tool!r} failed to execute tool (location: {location})."
)
Expand Down

0 comments on commit e8a28b8

Please sign in to comment.