Skip to content

Commit

Permalink
feat: allow multiple authentication methods
Browse files Browse the repository at this point in the history
  • Loading branch information
pquadri committed Sep 19, 2024
1 parent fb4fb03 commit 819a4c8
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
4 changes: 2 additions & 2 deletions snowflake_utils/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import typer
from typing_extensions import Annotated

from ..snowflake_utils.settings import SnowflakeSettings
from .models import Column, FileFormat, InlineFileFormat, Schema, Table
from .settings import SnowflakeSettings

app = typer.Typer()

Expand Down Expand Up @@ -43,7 +43,7 @@ def mass_single_column_update(
new_column = Column(name=new_column, data_type=data_type)
log_level = os.getenv("LOG_LEVEL", "INFO")
logging.getLogger("snowflake-utils").setLevel(log_level)
with SnowflakeSettings.connect() as conn, conn.cursor() as cursor:
with SnowflakeSettings().connect() as conn, conn.cursor() as cursor:
tables = db_schema.get_tables(cursor=cursor)
for table in tables:
columns = table.get_columns(cursor=cursor)
Expand Down
37 changes: 30 additions & 7 deletions snowflake_utils/settings.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,52 @@
from enum import Enum
from typing import Annotated

from pydantic import StringConstraints
from pydantic_settings import BaseSettings, SettingsConfigDict
from snowflake.connector import SnowflakeConnection
from snowflake.connector import connect as _connect


class Authenticator(str, Enum):
snowflake = "snowflake"
externalbrowser = "externalbrowser"
username_password_mfa = "username_password_mfa"

def __str__(self) -> str:
return self.value


OktaDomain = Annotated[
str,
StringConstraints(pattern=r"https://.*\.okta\.com"),
]


class SnowflakeSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="SNOWFLAKE_")

account: str = "snowflake-test"
user: str = "snowlfake"
password: str = "snowlfake"
db: str = "snowlfake"
role: str = "snowlfake"
warehouse: str = "snowlfake"
user: str = "snowflake"
password: str = "snowflake"
db: str = "snowflake"
role: str = "snowflake"
warehouse: str = "snowflake"
authenticator: Authenticator | OktaDomain = Authenticator.snowflake
_schema: str | None = None

def creds(self) -> dict[str, str | None]:
return {
base_creds = {
"account": self.account,
"user": self.user,
"password": self.password,
"database": self.db,
"schema": self._schema,
"role": self.role,
"warehouse": self.warehouse,
"authenticator": str(self.authenticator),
}
if self.authenticator in (Authenticator.externalbrowser):
return base_creds
return base_creds | {"password": self.password}

def connect(self) -> SnowflakeConnection:
return _connect(**self.creds())
Expand Down
23 changes: 23 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
from pydantic import ValidationError

from snowflake_utils.settings import SnowflakeSettings


@pytest.mark.parametrize(
"authenticator",
[
"snowflake",
"externalbrowser",
"username_password_mfa",
"valid.okta.com",
],
)
def test_authenticator(authenticator: str) -> None:
settings = SnowflakeSettings(authenticator=authenticator)
assert settings.authenticator == authenticator


def test_authenticator_invalid() -> None:
with pytest.raises(ValidationError):
SnowflakeSettings(authenticator="invalid")

0 comments on commit 819a4c8

Please sign in to comment.