Skip to content

Commit

Permalink
Add Test for Snapshot Status
Browse files Browse the repository at this point in the history
Signed-off-by: Andre Kurait <akurait@amazon.com>
  • Loading branch information
AndreKurait committed Jun 25, 2024
1 parent 2a45284 commit ae734ee
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from console_link.models.utils import ExitCode
from console_link.environment import Environment
from console_link.models.metrics_source import Component, MetricStatistic
from console_link.models.snapshot import SnapshotStatus
from click.shell_completion import get_completion_class

import logging
Expand Down Expand Up @@ -153,20 +152,18 @@ def snapshot_group(ctx):
def create_snapshot_cmd(ctx, wait, max_snapshot_rate_mb_per_node):
"""Create a snapshot of the source cluster"""
snapshot = ctx.env.snapshot
status, message = logic_snapshot.create(snapshot, wait=wait,
max_snapshot_rate_mb_per_node=max_snapshot_rate_mb_per_node)
if status != SnapshotStatus.COMPLETED and wait:
raise click.ClickException(message)
click.echo(message)
result = logic_snapshot.create(snapshot, wait=wait,
max_snapshot_rate_mb_per_node=max_snapshot_rate_mb_per_node)
click.echo(result.message)


@snapshot_group.command(name="status")
@click.option('--deep-check', is_flag=True, default=False, help='Perform a deep status check of the snapshot')
@click.pass_obj
def status_snapshot_cmd(ctx, deep_check):
"""Check the status of the snapshot"""
status = logic_snapshot.status(ctx.env.snapshot, deep_check=deep_check)
click.echo(f"Snapshot Status: {status}")
result = logic_snapshot.status(ctx.env.snapshot, deep_check=deep_check)
click.echo(result.message)

# ##################### BACKFILL ###################

Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
import logging
from typing import Tuple
from console_link.models.snapshot import Snapshot, SnapshotStatus
from console_link.models.snapshot import Snapshot
from console_link.models.command_result import CommandResult

logger = logging.getLogger(__name__)


def create(snapshot: Snapshot, *args, **kwargs) -> Tuple[SnapshotStatus, str]:
def create(snapshot: Snapshot, *args, **kwargs) -> CommandResult:
logger.info(f"Creating snapshot with {args=} and {kwargs=}")
try:
result = snapshot.create(*args, **kwargs)
return snapshot.create(*args, **kwargs)
except Exception as e:
logger.error(f"Failure running create snapshot: {e}")
return SnapshotStatus.FAILED, f"Failure running create snapshot: {e}"
return CommandResult(status=False, message=f"Failure running create snapshot: {e}")

if not result.success:
return SnapshotStatus.FAILED, "Snapshot creation failed." + "\n" + result.value

return status(snapshot, *args, **kwargs)


def status(snapshot: Snapshot, *args, **kwargs) -> str:
def status(snapshot: Snapshot, *args, **kwargs) -> CommandResult:
logger.info("Getting snapshot status")
return snapshot.status(*args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def create(self, *args, **kwargs) -> CommandResult:
pass

@abstractmethod
def status(self, *args, **kwargs) -> str:
def status(self, *args, **kwargs) -> CommandResult:
"""Get the status of the snapshot."""
pass

Expand Down Expand Up @@ -108,7 +108,10 @@ def create(self, *args, **kwargs) -> CommandResult:
return CommandResult(success=False, value=f"Failed to create snapshot: {str(e)}")

def status(self, *args, **kwargs) -> CommandResult:
return CommandResult(success=False, value="Command not implemented")
deep_check = kwargs.get('deep_check', False)
if deep_check:
return get_snapshot_status_full(self.source_cluster, self.snapshot_name)
return get_snapshot_status(self.source_cluster, self.snapshot_name)


class FileSystemSnapshot(Snapshot):
Expand Down Expand Up @@ -144,11 +147,8 @@ def create(self, *args, **kwargs) -> CommandResult:
logger.error(message)
return CommandResult(success=False, value=message)

def status(self, *args, **kwargs) -> str:
deep_check = kwargs.get('deep_check', False)
if deep_check:
return get_snapshot_status_full(self.source_cluster, self.snapshot_name)
return get_snapshot_status(self.source_cluster, self.snapshot_name)
def status(self, *args, **kwargs) -> CommandResult:
raise NotImplementedError("Status check for FileSystemSnapshot is not implemented yet.")


def parse_args():
Expand All @@ -163,18 +163,21 @@ def parse_args():


def get_snapshot_status(cluster: Cluster, snapshot: str,
repository: str = 'migration_assistant_repo') -> str:
repository: str = 'migration_assistant_repo') -> CommandResult:
path = f"/_snapshot/{repository}/{snapshot}"
response = cluster.call_api(path, HttpMethod.GET)
logging.debug(f"Raw get snapshot status response: {response.text}")
response.raise_for_status()
try:
response = cluster.call_api(path, HttpMethod.GET)
logging.debug(f"Raw get snapshot status response: {response.text}")
response.raise_for_status()

snapshot_data = response.json()
snapshots = snapshot_data.get('snapshots', [])
if not snapshots:
return "Snapshot not started"
snapshot_data = response.json()
snapshots = snapshot_data.get('snapshots', [])
if not snapshots:
return CommandResult(success=False, value="Snapshot not started")

return snapshots[0].get("state")
return CommandResult(success=True, value=snapshots[0].get("state"))
except Exception as e:
return CommandResult(success=False, value=f"Failed to get snapshot status: {str(e)}")


def get_repository_for_snapshot(cluster: Cluster, snapshot: str) -> Optional[str]:
Expand Down Expand Up @@ -248,33 +251,35 @@ def get_snapshot_status_message(snapshot_info: Dict) -> str:
)


def get_snapshot_status_full(cluster: Cluster, snapshot: str, repository: str = 'migration_assistant_repo')\
-> str:
repository = repository if repository != '*' else get_repository_for_snapshot(cluster, snapshot)
def get_snapshot_status_full(cluster: Cluster, snapshot: str,
repository: str = 'migration_assistant_repo') -> CommandResult:
try:
repository = repository if repository != '*' else get_repository_for_snapshot(cluster, snapshot)

path = f"/_snapshot/{repository}/{snapshot}"
response = cluster.call_api(path, HttpMethod.GET)
logging.debug(f"Raw get snapshot status response: {response.text}")
response.raise_for_status()
path = f"/_snapshot/{repository}/{snapshot}"
response = cluster.call_api(path, HttpMethod.GET)
logging.debug(f"Raw get snapshot status response: {response.text}")
response.raise_for_status()

snapshot_data = response.json()
snapshots = snapshot_data.get('snapshots', [])
if not snapshots:
return "Snapshot not started"
snapshot_data = response.json()
snapshots = snapshot_data.get('snapshots', [])
if not snapshots:
return CommandResult(success=False, value="Snapshot not started")

snapshot_info = snapshots[0]
state = snapshot_info.get("state")
snapshot_info = snapshots[0]
state = snapshot_info.get("state")

path = f"/_snapshot/{repository}/{snapshot}/_status"
response = cluster.call_api(path, HttpMethod.GET)
logging.debug(f"Raw get snapshot status full response: {response.text}")
response.raise_for_status()
path = f"/_snapshot/{repository}/{snapshot}/_status"
response = cluster.call_api(path, HttpMethod.GET)
logging.debug(f"Raw get snapshot status full response: {response.text}")
response.raise_for_status()

snapshot_data = response.json()
snapshots = snapshot_data.get('snapshots', [])
if not snapshots:
return "Snapshot not started"
snapshot_data = response.json()
snapshots = snapshot_data.get('snapshots', [])
if not snapshots:
return CommandResult(success=False, value="Snapshot status not available")

snapshot_info = snapshots[0]
message = get_snapshot_status_message(snapshot_info)
return state, message
message = get_snapshot_status_message(snapshot_info)
return CommandResult(success=True, value=f"{state}\n{message}")
except Exception as e:
return CommandResult(success=False, value=f"Failed to get full snapshot status: {str(e)}")
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
from console_link.models.snapshot import SnapshotStatus

import requests_mock
from console_link.models.command_result import CommandResult

from console_link.cli import cli
from console_link.environment import Environment
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_cli_snapshot_create(runner, env, mocker):
mock = mocker.patch('console_link.logic.snapshot.create')

# Set the mock return value
mock.return_value = SnapshotStatus.COMPLETED, "Snapshot created successfully."
mock.return_value = CommandResult(status=True, message="Snapshot created successfully.")

# Test snapshot creation
result = runner.invoke(cli, ['--config-file', str(VALID_SERVICES_YAML), 'snapshot', 'create'],
Expand All @@ -122,12 +122,11 @@ def test_cli_snapshot_create(runner, env, mocker):
mock.assert_called_once()


@pytest.mark.skip(reason="Not implemented yet")
def test_cli_snapshot_status(runner, env, mocker):
mock = mocker.patch('console_link.logic.snapshot.status')

# Set the mock return value
mock.return_value = SnapshotStatus.COMPLETED, "Snapshot status: COMPLETED"
mock.return_value = CommandResult(status=True, message="Snapshot status: COMPLETED")

# Test snapshot status
result = runner.invoke(cli, ['--config-file', str(VALID_SERVICES_YAML), 'snapshot', 'status'],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,93 @@
from console_link.models.snapshot import S3Snapshot, FileSystemSnapshot, Snapshot
from console_link.environment import get_snapshot
from console_link.models.cluster import AuthMethod
from console_link.models.cluster import AuthMethod, Cluster, HttpMethod
from tests.utils import create_valid_cluster
import pytest
import unittest.mock as mock
from console_link.models.command_result import CommandResult


@pytest.fixture
def mock_cluster():
cluster = mock.Mock(spec=Cluster)
return cluster


@pytest.fixture
def s3_snapshot(mock_cluster):
config = {
"snapshot_name": "test_snapshot",
"s3": {
"repo_uri": "s3://test-bucket",
"aws_region": "us-west-2"
}
}
return S3Snapshot(config, mock_cluster)


def test_s3_snapshot_status(s3_snapshot, mock_cluster):
mock_response = mock.Mock()
mock_response.json.return_value = {
"snapshots": [
{
"snapshot": "test_snapshot",
"state": "SUCCESS"
}
]
}
mock_cluster.call_api.return_value = mock_response

result = s3_snapshot.status()

assert isinstance(result, CommandResult)
assert result.success
assert result.value == "SUCCESS"
mock_cluster.call_api.assert_called_once_with("/_snapshot/migration_assistant_repo/test_snapshot",
HttpMethod.GET)


def test_s3_snapshot_status_full(s3_snapshot, mock_cluster):
mock_response = mock.Mock()
mock_response.json.return_value = {
"snapshots": [
{
"snapshot": "test_snapshot",
"state": "SUCCESS",
"shards_stats": {
"total": 10,
"done": 10,
"failed": 0
},
"stats": {
"total": {
"size_in_bytes": 1000000
},
"processed": {
"size_in_bytes": 1000000
},
"start_time_in_millis": 1625097600000,
"time_in_millis": 60000
}
}
]
}
mock_cluster.call_api.return_value = mock_response

result = s3_snapshot.status(deep_check=True)

assert isinstance(result, CommandResult)
assert result.success
assert "SUCCESS" in result.value
assert "Percent completed: 100.00%" in result.value
assert "Total shards: 10" in result.value
assert "Successful shards: 10" in result.value
assert "Failed shards: 0" in result.value
assert "Start time:" in result.value
assert "Duration:" in result.value
assert "Anticipated duration remaining:" in result.value
assert "Throughput:" in result.value
mock_cluster.call_api.assert_called_with("/_snapshot/migration_assistant_repo/test_snapshot/_status",
HttpMethod.GET)


def test_s3_snapshot_init_succeeds():
Expand Down

0 comments on commit ae734ee

Please sign in to comment.