Skip to content

Commit

Permalink
feat: trino integration get tables name api developed (#723)
Browse files Browse the repository at this point in the history
* feat: trino integration added

* fix: all test cases are now passing
  • Loading branch information
himanshu634 authored Aug 2, 2024
1 parent 39e2025 commit 1cb6067
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ibis-server/app/model/metadata/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from app.model.metadata.mssql import MSSQLMetadata
from app.model.metadata.mysql import MySQLMetadata
from app.model.metadata.postgres import PostgresMetadata
from app.model.metadata.trino import TrinoMetadata


class MetadataFactory:
Expand All @@ -29,6 +30,8 @@ def get_metadata(self, data_source: DataSource, connection_info) -> Metadata:
return MSSQLMetadata(connection_info)
if data_source == DataSource.clickhouse:
return ClickHouseMetadata(connection_info)
if data_source == DataSource.trino:
return TrinoMetadata(connection_info)

raise NotImplementedError(f"Unsupported data source: {self}")

Expand Down
131 changes: 131 additions & 0 deletions ibis-server/app/model/metadata/trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from json import loads
from urllib.parse import urlparse

import pandas as pd

from app.model import ConnectionUrl, TrinoConnectionInfo
from app.model.data_source import DataSource
from app.model.metadata.dto import (
Column,
Constraint,
Table,
TableProperties,
WrenEngineColumnType,
)
from app.model.metadata.metadata import Metadata


class TrinoMetadata(Metadata):
    """Metadata provider that lists tables and columns of a Trino data source.

    Table and column definitions are read from ``information_schema``;
    constraints are not reported (``get_constraints`` always returns ``[]``).
    """

    def __init__(self, connection_info: TrinoConnectionInfo | ConnectionUrl):
        super().__init__(connection_info)

    def get_table_list(self) -> list[Table]:
        """Return all base tables and views together with their columns.

        A single query joins ``information_schema.tables`` with
        ``information_schema.columns`` and yields one row per column. The rows
        are grouped into ``Table`` objects keyed by ``catalog.schema.table``.

        Returns:
            One ``Table`` per distinct table, in first-seen order, each
            carrying the ``Column`` entries found for it.
        """
        sql = """SELECT
                t.table_catalog,
                t.table_schema,
                t.table_name,
                c.column_name,
                c.data_type,
                c.is_nullable
            FROM
                information_schema.tables t
            JOIN
                information_schema.columns c
                ON t.table_schema = c.table_schema
                AND t.table_name = c.table_name
            WHERE
                t.table_type IN ('BASE TABLE', 'VIEW')
                AND t.table_schema NOT IN ('information_schema', 'pg_catalog')"""

        sql_cursor = DataSource.trino.get_connection(self.connection_info).raw_sql(sql)
        column_names = [col[0] for col in sql_cursor.description]
        # Round-trip through pandas/JSON to get a list of plain dicts, one per
        # result row, keyed by the selected column names.
        response = loads(
            pd.DataFrame(sql_cursor.fetchall(), columns=column_names).to_json(
                orient="records"
            )
        )

        unique_tables = {}
        for row in response:
            # Fully qualified name used both as grouping key and table name.
            schema_table = self._format_trino_compact_table_name(
                row["table_catalog"], row["table_schema"], row["table_name"]
            )
            # First column seen for this table: create the table entry.
            if schema_table not in unique_tables:
                unique_tables[schema_table] = Table(
                    name=schema_table,
                    description="",
                    columns=[],
                    properties=TableProperties(
                        schema=row["table_schema"],
                        catalog=row["table_catalog"],
                        table=row["table_name"],
                    ),
                    primaryKey="",
                )

            # Attach the current column to its (now existing) table.
            unique_tables[schema_table].columns.append(
                Column(
                    name=row["column_name"],
                    type=self._transform_column_type(row["data_type"]),
                    notNull=row["is_nullable"].lower() == "no",
                    description="",
                    properties=None,
                )
            )
        return list(unique_tables.values())

    def get_constraints(self) -> list[Constraint]:
        """Return an empty list; no constraints are collected for Trino."""
        return []

    def _format_trino_compact_table_name(
        self, catalog: str, schema: str, table: str
    ) -> str:
        """Build the dotted ``catalog.schema.table`` identifier."""
        return f"{catalog}.{schema}.{table}"

    def _get_schema_name(self):
        """Return the schema named by the connection info.

        For URL-style connection info this is the last path segment of the
        URL; otherwise it is the value of the ``trino_schema`` field.
        """
        if hasattr(self.connection_info, "connection_url"):
            return urlparse(
                self.connection_info.connection_url.get_secret_value()
            ).path.split("/")[-1]
        else:
            return self.connection_info.trino_schema.get_secret_value()

    def _transform_column_type(self, data_type):
        """Map a Trino ``data_type`` string to a ``WrenEngineColumnType``.

        Type parameters are ignored, so ``varchar(25)``, ``decimal(12,2)`` and
        ``timestamp(3) with time zone`` resolve through their base type names.
        Names without an entry map to ``WrenEngineColumnType.UNKNOWN``.
        """
        # Drop everything from the first "(" through the last ")" so that
        # parameterised types are looked up by their base name only.
        head, paren, rest = data_type.partition("(")
        if paren:
            data_type = head + rest.rpartition(")")[2]
        # Normalise case and collapse whitespace left behind by the removal.
        base_type = " ".join(data_type.lower().split())

        # Trino types: https://trino.io/docs/current/language/types.html
        # Several keys below are not Trino type names (they were carried over
        # from another connector); they are kept for backward compatibility.
        switcher = {
            # String types (binary and spatial types are not mapped yet)
            "char": WrenEngineColumnType.CHAR,
            "varchar": WrenEngineColumnType.VARCHAR,
            "tinytext": WrenEngineColumnType.TEXT,
            "text": WrenEngineColumnType.TEXT,
            "mediumtext": WrenEngineColumnType.TEXT,
            "longtext": WrenEngineColumnType.TEXT,
            "enum": WrenEngineColumnType.VARCHAR,
            "set": WrenEngineColumnType.VARCHAR,
            # Integer types
            "bit": WrenEngineColumnType.TINYINT,
            "tinyint": WrenEngineColumnType.TINYINT,
            "smallint": WrenEngineColumnType.SMALLINT,
            "mediumint": WrenEngineColumnType.INTEGER,
            "int": WrenEngineColumnType.INTEGER,
            "integer": WrenEngineColumnType.INTEGER,
            "bigint": WrenEngineColumnType.BIGINT,
            # Boolean
            "bool": WrenEngineColumnType.BOOLEAN,
            "boolean": WrenEngineColumnType.BOOLEAN,
            # Floating-point and fixed-precision types
            "float": WrenEngineColumnType.FLOAT8,
            "double": WrenEngineColumnType.DOUBLE,
            "decimal": WrenEngineColumnType.DECIMAL,
            "numeric": WrenEngineColumnType.NUMERIC,
            # Date and time types: in Trino a plain ``timestamp`` carries no
            # time zone; the zoned variant is ``timestamp with time zone``.
            "date": WrenEngineColumnType.DATE,
            "datetime": WrenEngineColumnType.TIMESTAMP,
            "timestamp": WrenEngineColumnType.TIMESTAMP,
            "timestamp with time zone": WrenEngineColumnType.TIMESTAMPTZ,
            # JSON
            "json": WrenEngineColumnType.JSON,
        }

        return switcher.get(base_type, WrenEngineColumnType.UNKNOWN)
32 changes: 32 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,35 @@ def to_connection_info(trino: TrinoContainer):
def to_connection_url(trino: TrinoContainer):
    """Build a ``trino://`` connection URL from the container's connection info."""
    info = to_connection_info(trino)
    user, host, port = info["user"], info["host"], info["port"]
    catalog, schema = info["catalog"], info["schema"]
    return f"trino://{user}@{host}:{port}/{catalog}/{schema}"


def test_metadata_list_tables(trino: TrinoContainer):
    """The tables endpoint answers 200 and its first entry has all fields set."""
    payload = {"connectionInfo": to_connection_info(trino)}
    response = client.post(url=f"{base_url}/metadata/tables", json=payload)
    assert response.status_code == 200

    # Only the first returned table is inspected; each field must be non-null.
    first_table = response.json()[0]
    for field in ("name", "columns", "primaryKey", "description", "properties"):
        assert first_table[field] is not None


def test_metadata_list_constraints(trino: TrinoContainer):
    """The constraints endpoint answers 200 with an empty collection."""
    payload = {"connectionInfo": to_connection_info(trino)}
    response = client.post(url=f"{base_url}/metadata/constraints", json=payload)
    assert response.status_code == 200

    constraints = response.json()
    assert len(constraints) == 0

0 comments on commit 1cb6067

Please sign in to comment.