Skip to content

Commit

Permalink
Minor improvemetns in the agent spec validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabioz committed Sep 12, 2024
1 parent 5669551 commit 90db2c1
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 38 deletions.
86 changes: 84 additions & 2 deletions sema4ai/src/sema4ai_code/agents/agent_spec_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import typing
import weakref
from dataclasses import dataclass
Expand All @@ -10,7 +11,8 @@
if typing.TYPE_CHECKING:
from tree_sitter import Node, Tree

from sema4ai_code.agents.list_actions_from_agent import ActionPackageInFilesystem
from .list_actions_from_agent import ActionPackageInFilesystem


log = get_logger(__name__)

Expand Down Expand Up @@ -154,11 +156,90 @@ def load_spec(json_spec: dict[str, Any]) -> dict[str, Entry]:
T = TypeVar("T")


class Severity(enum.Enum):
critical = "critical"
warning = "warning"
info = "info"


def _create_range_from_location(
start_line: int,
start_col: int,
end_line: Optional[int] = None,
end_col: Optional[int] = None,
) -> dict:
"""
If the end_line and end_col aren't passed we consider
that the location should go up until the end of the line.
"""
if end_line is None:
assert end_col is None
end_line = start_line + 1
end_col = 0
assert end_col is not None
dct: dict = {
"start": {
"line": start_line,
"character": start_col,
},
"end": {
"line": end_line,
"character": end_col,
},
}
return dct


@dataclass
class Error:
message: str
node: Optional["Node"] = None
code: Optional[ErrorCode] = None
severity: Severity = Severity.critical

def as_diagostic(self, agent_node) -> dict:
from typing import Sequence

use_location: Sequence[int]
error = self

if not error.node:
use_location = (0, 0, 1, 0)
if agent_node is not None:
if agent_node.key.location:
use_location = agent_node.key.location
else:
start_line, start_col = (
error.node.start_point.row,
error.node.start_point.column,
)
end_line, end_col = (
error.node.end_point.row,
error.node.end_point.column,
)
use_location = start_line, start_col, end_line, end_col

use_range = _create_range_from_location(*use_location)

if error.severity == Severity.critical:
severity = 1
elif error.severity == Severity.warning:
severity = 2
elif error.severity == Severity.info:
severity = 3
else:
raise RuntimeError(f"Unexpected severity: {error.severity}")

diagnostic = {
"range": use_range,
"severity": severity,
"source": "sema4ai",
"message": error.message,
}

if error.code:
diagnostic["code"] = error.code.value
return diagnostic


class TreeNode(Generic[T]):
Expand Down Expand Up @@ -788,10 +869,11 @@ def _validate_unreferenced_action_packages(self) -> Iterator[Error]:
if report_error_at_node is not None
else None,
code=ErrorCode.action_package_info_unsynchronized,
severity=Severity.warning,
)

def validate(self, node: "Node") -> Iterator[Error]:
from sema4ai_code.agents.list_actions_from_agent import list_actions_from_agent
from .list_actions_from_agent import list_actions_from_agent

self._action_packages_found_in_filesystem = list_actions_from_agent(
self._agent_root_dir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,8 @@ class _Version:
def validate_agent(
doc: IDocument, agent_node: _EntryNode, version: _Version
) -> Iterator[DiagnosticsTypedDict]:
from typing import Sequence

from sema4ai_code.agents.agent_spec import AGENT_SPEC_V2
from sema4ai_code.agents.agent_spec_handler import Error, validate_from_spec
from sema4ai_code.vendored_deps.yaml_with_location import create_range_from_location

if version.is_v1:
yield from validate_sections_v1(
Expand All @@ -245,36 +242,7 @@ def validate_agent(
pathlib.Path(doc.path).parent,
raise_on_error=False,
):
use_location: Sequence[int]
if not error.node:
use_location = (0, 0, 1, 0)
if agent_node.key.location:
use_location = agent_node.key.location
else:
start_line, start_col = (
error.node.start_point.row,
error.node.start_point.column,
)
end_line, end_col = (
error.node.end_point.row,
error.node.end_point.column,
)
use_location = start_line, start_col, end_line, end_col

use_range = create_range_from_location(*use_location)

diagnostic: DiagnosticsTypedDict

diagnostic = {
"range": use_range,
"severity": DiagnosticSeverity.Error,
"source": "sema4ai",
"message": error.message,
}

if error.code:
diagnostic["code"] = error.code.value
yield diagnostic
yield typing.cast(DiagnosticsTypedDict, error.as_diagostic(agent_node))

else:
raise AssertionError(f"Unexpected version: {version}")
Expand Down
15 changes: 12 additions & 3 deletions sema4ai/src/sema4ai_code/agents/list_actions_from_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def get_as_dict(self) -> dict:
import yaml

try:
if self.is_zip() and self.package_yaml_contents is None:
raise RuntimeError(
"It was not possible to load the agent-spec.yaml from the referenced .zip file."
)
if self.package_yaml_contents is not None:
contents = yaml.safe_load(self.package_yaml_contents)
else:
Expand All @@ -58,10 +62,15 @@ def get_as_dict(self) -> dict:

self._loaded_yaml = contents
return self._loaded_yaml
except Exception:
log.error(f"Error getting {self.package_yaml_path} as yaml.")
except Exception as e:
if self.is_zip():
log.error(
f"Error getting agent-spec.yaml from {self.zip_path} as yaml."
)
else:
log.error(f"Error getting {self.package_yaml_path} as yaml.")

self._loaded_yaml_error = "Unable to load package.yaml as yaml"
self._loaded_yaml_error = str(e)
raise

def get_version(self) -> str:
Expand Down

0 comments on commit 90db2c1

Please sign in to comment.