class TrinoMetadata(Metadata):
    """Metadata provider that reads table and column definitions from Trino.

    Tables are discovered through ``information_schema.tables`` joined with
    ``information_schema.columns``. Constraints are not reported.
    """

    # Lower-cased base type name (length/precision arguments already removed)
    # -> Wren engine column type. Trino type reference:
    # https://trino.io/docs/current/language/types.html
    # Built once at class creation instead of once per column.
    # NOTE(review): several keys (tinytext, mediumint, enum, set, bit, bool,
    # float, datetime, numeric, ...) are not native Trino type names; they are
    # retained so that every name the previous table accepted still resolves.
    # NOTE(review): Trino's "real" has no entry and resolves to UNKNOWN --
    # add it once the matching WrenEngineColumnType member is confirmed.
    _TYPE_MAPPING = {
        # String types (binary and spatial types are ignored for now)
        "char": WrenEngineColumnType.CHAR,
        "varchar": WrenEngineColumnType.VARCHAR,
        "tinytext": WrenEngineColumnType.TEXT,
        "text": WrenEngineColumnType.TEXT,
        "mediumtext": WrenEngineColumnType.TEXT,
        "longtext": WrenEngineColumnType.TEXT,
        "enum": WrenEngineColumnType.VARCHAR,
        "set": WrenEngineColumnType.VARCHAR,
        # Integer types
        "bit": WrenEngineColumnType.TINYINT,
        "tinyint": WrenEngineColumnType.TINYINT,
        "smallint": WrenEngineColumnType.SMALLINT,
        "mediumint": WrenEngineColumnType.INTEGER,
        "int": WrenEngineColumnType.INTEGER,
        "integer": WrenEngineColumnType.INTEGER,
        "bigint": WrenEngineColumnType.BIGINT,
        # Boolean
        "bool": WrenEngineColumnType.BOOLEAN,
        "boolean": WrenEngineColumnType.BOOLEAN,
        # Floating point and fixed precision
        "float": WrenEngineColumnType.FLOAT8,
        "double": WrenEngineColumnType.DOUBLE,
        "decimal": WrenEngineColumnType.DECIMAL,
        "numeric": WrenEngineColumnType.NUMERIC,
        # Date and time. Trino's plain "timestamp" carries no time zone; the
        # "with time zone" variant is handled in _transform_column_type.
        "date": WrenEngineColumnType.DATE,
        "datetime": WrenEngineColumnType.TIMESTAMP,
        "timestamp": WrenEngineColumnType.TIMESTAMP,
        # JSON
        "json": WrenEngineColumnType.JSON,
    }

    _WITH_TIME_ZONE_SUFFIX = " with time zone"

    def __init__(self, connection_info: TrinoConnectionInfo | ConnectionUrl):
        super().__init__(connection_info)

    def get_table_list(self) -> list[Table]:
        """Return every base table and view visible through information_schema.

        Rows of ``information_schema.tables`` joined with
        ``information_schema.columns`` are grouped by
        ``catalog.schema.table``; each group becomes one ``Table`` whose
        ``columns`` list holds one ``Column`` per joined row. The
        ``information_schema`` and ``pg_catalog`` schemas are excluded.

        Returns:
            The tables in the order they were first seen in the result set.
        """
        sql = """SELECT
                t.table_catalog,
                t.table_schema,
                t.table_name,
                c.column_name,
                c.data_type,
                c.is_nullable
            FROM
                information_schema.tables t
            JOIN
                information_schema.columns c
                ON t.table_schema = c.table_schema
                AND t.table_name = c.table_name
            WHERE
                t.table_type IN ('BASE TABLE', 'VIEW')
                AND t.table_schema NOT IN ('information_schema', 'pg_catalog')"""

        cursor = DataSource.trino.get_connection(self.connection_info).raw_sql(sql)
        try:
            column_names = [col[0] for col in cursor.description]
            # Plain dicts are enough here; no DataFrame/JSON round trip needed.
            rows = [dict(zip(column_names, record)) for record in cursor.fetchall()]
        finally:
            # Release the cursor even when reading the result fails.
            cursor.close()

        unique_tables = {}
        for row in rows:
            # Fully qualified name doubles as the grouping key.
            schema_table = self._format_trino_compact_table_name(
                row["table_catalog"], row["table_schema"], row["table_name"]
            )
            # First column seen for this table: create the table entry.
            if schema_table not in unique_tables:
                unique_tables[schema_table] = Table(
                    name=schema_table,
                    description="",
                    columns=[],
                    properties=TableProperties(
                        schema=row["table_schema"],
                        catalog=row["table_catalog"],
                        table=row["table_name"],
                    ),
                    primaryKey="",
                )

            unique_tables[schema_table].columns.append(
                Column(
                    name=row["column_name"],
                    type=self._transform_column_type(row["data_type"]),
                    notNull=row["is_nullable"].lower() == "no",
                    description="",
                    properties=None,
                )
            )
        return list(unique_tables.values())

    def get_constraints(self) -> list[Constraint]:
        """Return an empty list; no constraints are collected for Trino."""
        return []

    def _format_trino_compact_table_name(
        self, catalog: str, schema: str, table: str
    ) -> str:
        """Build the dotted ``catalog.schema.table`` name used as table key."""
        return f"{catalog}.{schema}.{table}"

    def _get_schema_name(self):
        """Return the schema named by the connection info.

        For URL-style connection info the last ``/``-separated segment of the
        URL path is used; otherwise the ``trino_schema`` secret is returned.
        """
        if hasattr(self.connection_info, "connection_url"):
            url = self.connection_info.connection_url.get_secret_value()
            return urlparse(url).path.split("/")[-1]
        return self.connection_info.trino_schema.get_secret_value()

    def _transform_column_type(self, data_type):
        """Translate a Trino ``data_type`` string into a Wren engine type.

        The name is matched case-insensitively after removing any
        parenthesised arguments, so ``varchar(25)``, ``decimal(12,2)`` and
        ``timestamp(3)`` resolve like their bare forms.
        ``timestamp ... with time zone`` maps to ``TIMESTAMPTZ``; a plain
        ``timestamp`` maps to ``TIMESTAMP``.

        Args:
            data_type: Type name as reported by information_schema.columns.

        Returns:
            The matching ``WrenEngineColumnType``, or ``UNKNOWN`` when the
            name is empty or has no mapping.
        """
        if not data_type:
            return WrenEngineColumnType.UNKNOWN

        type_name = data_type.strip().lower()
        has_time_zone = type_name.endswith(self._WITH_TIME_ZONE_SUFFIX)
        if has_time_zone:
            type_name = type_name[: -len(self._WITH_TIME_ZONE_SUFFIX)]

        # Drop "(...)" arguments such as length, precision or element types.
        base_name = type_name.partition("(")[0].strip()

        if has_time_zone:
            # Only the timestamp variant has a time-zone-aware mapping here.
            if base_name == "timestamp":
                return WrenEngineColumnType.TIMESTAMPTZ
            return WrenEngineColumnType.UNKNOWN

        return self._TYPE_MAPPING.get(base_name, WrenEngineColumnType.UNKNOWN)
def test_metadata_list_tables(trino: TrinoContainer):
    """The tables endpoint answers 200 and its first entry has every field set."""
    payload = {"connectionInfo": to_connection_info(trino)}
    response = client.post(url=f"{base_url}/metadata/tables", json=payload)
    assert response.status_code == 200

    first_table = response.json()[0]
    # Same fields, same order as the individual checks they replace.
    for field in ("name", "columns", "primaryKey", "description", "properties"):
        assert first_table[field] is not None


def test_metadata_list_constraints(trino: TrinoContainer):
    """The constraints endpoint answers 200 with an empty collection."""
    payload = {"connectionInfo": to_connection_info(trino)}
    response = client.post(url=f"{base_url}/metadata/constraints", json=payload)
    assert response.status_code == 200

    constraints = response.json()
    assert len(constraints) == 0