
Commit
Merge branch 'main' into 612-introduced-recap-search-alerts
albertisfu committed Jul 25, 2024
2 parents a468336 + 9b9bc1b commit 38d6884
Showing 10 changed files with 259 additions and 139 deletions.
159 changes: 106 additions & 53 deletions cl/corpus_importer/management/commands/make_aws_manifest_files.py
@@ -14,53 +14,63 @@
s3_client = boto3.client("s3")


def get_total_number_of_records(type: str, use_replica: bool = False) -> int:
def get_total_number_of_records(type: str, options: dict[str, Any]) -> int:
"""
Retrieves the total number of records for a specific data type.
Args:
type (str): The type of data to count. Must be one of the valid values
from the `SEARCH_TYPES` class.
use_replica (bool, optional): Whether to use the replica database
connection (default: False).
options (dict[str, Any]): A dictionary containing options for filtering
the results.
- 'use_replica' (bool, optional): Whether to use the replica database
connection (default: False).
- 'random_sample_percentage' (float, optional): The percentage of
records to include in a random sample.
Returns:
int: The total number of records matching the specified data type.
"""
match type:
case SEARCH_TYPES.RECAP_DOCUMENT:
query = """
SELECT count(*) AS exact_count
FROM search_recapdocument
base_query = (
"SELECT count(*) AS exact_count FROM search_recapdocument"
)
filter_clause = """
WHERE is_available=True AND page_count>0 AND ocr_status!=1
"""
case SEARCH_TYPES.OPINION:
query = """
SELECT count(*) AS exact_count
FROM search_opinion
WHERE extracted_by_ocr != true
"""
base_query = "SELECT count(*) AS exact_count FROM search_opinion"
filter_clause = "WHERE extracted_by_ocr != true"
case SEARCH_TYPES.ORAL_ARGUMENT:
query = """
SELECT count(*) AS exact_count
FROM audio_audio
WHERE
local_path_mp3 != '' AND
base_query = "SELECT count(*) AS exact_count FROM audio_audio"
filter_clause = """WHERE local_path_mp3 != '' AND
download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
position('Unavailable' in download_url) = 0 AND
duration > 30
"""

if options["random_sample_percentage"]:
percentage = options["random_sample_percentage"]
base_query = f"{base_query} TABLESAMPLE SYSTEM ({percentage})"

query = (
f"{base_query}\n"
if options["all_records"]
else f"{base_query}\n {filter_clause}\n"
)
with connections[
"replica" if use_replica else "default"
"replica" if options["use_replica"] else "default"
].cursor() as cursor:
cursor.execute(query, [])
result = cursor.fetchone()

return int(result[0])

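Note: TABLESAMPLE SYSTEM performs page-level sampling in PostgreSQL, so the requested percentage is approximate rather than an exact row fraction. A minimal sketch of how the pieces above compose, with illustrative values:

    # Sketch of the count-query assembly above; assumes PostgreSQL,
    # since TABLESAMPLE SYSTEM is Postgres syntax. Values illustrative.
    options = {
        "random_sample_percentage": 0.1,
        "all_records": False,
        "use_replica": False,
    }
    base_query = "SELECT count(*) AS exact_count FROM search_recapdocument"
    filter_clause = "WHERE is_available=True AND page_count>0 AND ocr_status!=1"

    if options["random_sample_percentage"]:
        # TABLESAMPLE must follow the table name, before any WHERE clause.
        percentage = options["random_sample_percentage"]
        base_query = f"{base_query} TABLESAMPLE SYSTEM ({percentage})"

    query = base_query if options["all_records"] else f"{base_query}\n {filter_clause}"
    # -> SELECT count(*) AS exact_count FROM search_recapdocument
    #    TABLESAMPLE SYSTEM (0.1)
    #    WHERE is_available=True AND page_count>0 AND ocr_status!=1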

def get_custom_query(type: str, last_pk: str) -> tuple[str, list[Any]]:
def get_custom_query(
type: str, last_pk: str, options: dict[str, Any]
) -> tuple[str, list[Any]]:
"""
Generates a custom SQL query based on the provided type and optional last
pk.
@@ -69,57 +79,59 @@ def get_custom_query(type: str, last_pk: str) -> tuple[str, list[Any]]:
type (str): Type of data to retrieve.
last_pk (int, optional): Last primary key retrieved in a previous
query. Defaults to None.
options (dict[str, Any]): A dictionary containing options for filtering
the results.
- 'random_sample_percentage' (float, optional): The percentage of
records to include in a random sample.
Returns:
tuple[str, list[Any]]: A tuple containing the constructed SQL
query(str) and a list of parameters (list[Any]) to be used with
the query.
"""
params = []

random_sample = options["random_sample_percentage"]
match type:
case SEARCH_TYPES.RECAP_DOCUMENT:
base_query = "SELECT id from search_recapdocument"
filter_clause = (
"WHERE is_available=True AND page_count>0 AND ocr_status!=1"
if not last_pk
else (
"WHERE id > %s AND is_available = True AND page_count > 0"
" AND ocr_status != 1"
)
)
case SEARCH_TYPES.OPINION:
base_query = "SELECT id from search_opinion"
filter_clause = (
"WHERE extracted_by_ocr != true"
if not last_pk
else "WHERE id > %s AND extracted_by_ocr != true"
)
filter_clause = "WHERE extracted_by_ocr != true"
case SEARCH_TYPES.ORAL_ARGUMENT:
base_query = "SELECT id from audio_audio"
no_argument_where_clause = """
filter_clause = """
WHERE local_path_mp3 != '' AND
download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
position('Unavailable' in download_url) = 0 AND
duration > 30
"""
where_clause_with_argument = """
WHERE id > %s AND
local_path_mp3 != '' AND
download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
position('Unavailable' in download_url) = 0 AND
duration > 30
"""
filter_clause = (
no_argument_where_clause
if not last_pk
else where_clause_with_argument
)

if last_pk:
if random_sample:
base_query = f"{base_query} TABLESAMPLE SYSTEM ({random_sample})"

if options["all_records"]:
filter_clause = ""

# Using a WHERE clause with `id > last_pk` and a LIMIT clause for batch
# retrieval is not suitable for random sampling. The following logic
# removes these clauses when retrieving a random sample to ensure all rows
# have an equal chance of being selected.
if last_pk and not random_sample:
filter_clause = (
f"WHERE id > %s"
if not filter_clause
else f"{filter_clause} AND id > %s"
)
params.append(last_pk)

query = f"{base_query}\n {filter_clause}\n ORDER BY id\n LIMIT %s"
query = (
f"{base_query}\n {filter_clause}"
if random_sample
else f"{base_query}\n {filter_clause}\n ORDER BY id\n LIMIT %s"
)

return query, params

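For illustration, the two query shapes get_custom_query now returns, with hypothetical argument values (exact whitespace differs):

    query, params = get_custom_query(
        SEARCH_TYPES.OPINION,
        last_pk=500,  # hypothetical cursor from a previous batch
        options={"random_sample_percentage": None, "all_records": False},
    )
    # query  -> SELECT id from search_opinion
    #           WHERE extracted_by_ocr != true AND id > %s
    #           ORDER BY id
    #           LIMIT %s
    # params -> [500]; the caller appends the batch size afterwards.

    query, params = get_custom_query(
        SEARCH_TYPES.OPINION,
        last_pk=None,
        options={"random_sample_percentage": 0.5, "all_records": False},
    )
    # query  -> SELECT id from search_opinion TABLESAMPLE SYSTEM (0.5)
    #           WHERE extracted_by_ocr != true
    # params -> []  (no id cursor, no ORDER BY/LIMIT, so every sampled
    #                row keeps an equal chance of selection)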
@@ -170,6 +182,27 @@ def add_arguments(self, parser: CommandParser):
default=False,
help="Use this flag to run the queries in the replica db",
)
parser.add_argument(
"--file-name",
type=str,
default=None,
help="Custom name for the output files. If not provided, a default "
"name will be used.",
)
parser.add_argument(
"--random-sample-percentage",
type=float,
default=None,
help="Specifies the proportion of the table to be sampled (between "
"0.0 and 100.0). Use this flag to retrieve a random set of records.",
)
parser.add_argument(
"--all-records",
action="store_true",
default=False,
help="Use this flag to retrieve all records from the table without"
" applying any filters.",
)
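A hypothetical invocation via Django's call_command. The record_type option name is inferred from the options[...] reads in handle() and may not match the collapsed flag definitions above; any required arguments not shown in this diff are omitted:

    from django.core.management import call_command

    from cl.search.models import SEARCH_TYPES

    call_command(
        "make_aws_manifest_files",
        record_type=SEARCH_TYPES.RECAP_DOCUMENT,  # inferred option name
        random_sample_percentage=0.1,  # sample roughly 0.1% of the table
        file_name="recap_sample",  # objects written as recap_sample_<n>.csv
        use_replica=True,
    )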

def handle(self, *args, **options):
r = get_redis_interface("CACHE")
@@ -188,7 +221,7 @@ def handle(self, *args, **options):
)
if not total_number_of_records:
total_number_of_records = get_total_number_of_records(
record_type, options["use_replica"]
record_type, options
)
r.hset(
f"{record_type}_import_status",
@@ -200,12 +233,17 @@ def handle(self, *args, **options):
r.hget(f"{record_type}_import_status", "next_iteration_counter")
or 0
)
file_name = (
options["file_name"]
if options["file_name"]
else f"{record_type}_filelist"
)
while True:
query, params = get_custom_query(
options["record_type"],
last_pk,
options["record_type"], last_pk, options
)
params.append(options["query_batch_size"])
if not options["random_sample_percentage"]:
params.append(options["query_batch_size"])

with connections[
"replica" if options["use_replica"] else "default"
@@ -226,22 +264,37 @@ def handle(self, *args, **options):
extrasaction="ignore",
)
for row in batched(rows, options["lambda_record_size"]):
                    query_dict = {
                        "bucket": bucket_name,
                        "file_name": (
                            f"{row[0][0]}_{row[-1][0]}"
                            if len(row) > 1
                            else f"{row[0][0]}"
                        ),
                    }
                    if options["random_sample_percentage"]:
                        # Create an underscore-separated file name that lambda
                        # can split and use as part of batch processing.
                        ids = [str(r[0]) for r in row]
                        content = "_".join(ids)
                    else:
                        content = (
                            f"{row[0][0]}_{row[-1][0]}"
                            if len(row) > 1
                            else f"{row[0][0]}"
                        )
                    query_dict = {
                        "bucket": bucket_name,
                        "file_name": content,
                    }
writer.writerow(query_dict)

s3_client.put_object(
Key=f"{record_type}_filelist_{counter}.csv",
Key=f"{file_name}_{counter}.csv",
Bucket=bucket_name,
Body=csvfile.getvalue().encode("utf-8"),
)

if options["random_sample_percentage"]:
# Due to the non-deterministic nature of random sampling,
# storing data to recover the query for future executions
# wouldn't be meaningful. Random queries are unlikely to
# produce the same results on subsequent runs.
logger.info(f"Finished processing {record_count} records")
break

counter += 1
last_pk = rows[-1][0]
records_processed = int(
…
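Worth noting: the manifest's file_name column doubles as a work descriptor for the downstream lambda. A sketch of both naming modes, assuming batched() has itertools.batched semantics (stdlib in Python 3.12+; the command may import its own helper):

    from itertools import batched

    rows = [(101,), (102,), (103,), (104,), (105,)]  # illustrative ids
    for row in batched(rows, 2):  # lambda_record_size = 2
        # Sequential mode: an id range, "<first>_<last>" per batch.
        range_name = (
            f"{row[0][0]}_{row[-1][0]}" if len(row) > 1 else f"{row[0][0]}"
        )
        # Random-sample mode: every id, underscore-separated, because
        # sampled ids are not contiguous and a range would not name them.
        sample_name = "_".join(str(r[0]) for r in row)
    # Batches: (101, 102) -> "101_102" either way; final (105,) -> "105".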
32 changes: 31 additions & 1 deletion cl/custom_filters/templatetags/extras.py
@@ -14,7 +14,7 @@
from elasticsearch_dsl import AttrDict, AttrList

from cl.search.constants import ALERTS_HL_TAG, SEARCH_HL_TAG
from cl.search.models import SEARCH_TYPES, Docket, DocketEntry
from cl.search.models import SEARCH_TYPES, Court, Docket, DocketEntry

register = template.Library()

@@ -297,3 +297,33 @@ def alerts_supported(context: RequestContext, search_type: str) -> str:
and waffle.flag_is_active(request, "recap-alerts-active")
)
)


@register.filter
def group_courts(courts: list[Court], num_columns: int) -> list:
"""Divide courts in equal groupings while keeping related courts together
:param courts: Courts to group.
:param num_columns: Number of groups wanted
:return: The courts grouped together
"""

column_len = len(courts) // num_columns
remainder = len(courts) % num_columns

groups = []
start = 0
for index in range(num_columns):
# Calculate the end index for this chunk
end = start + column_len + (1 if index < remainder else 0)

# Push the end index forward to the next COLR (court of last
# resort) so each group after the first starts with one
COLRs = [Court.TERRITORY_SUPREME, Court.STATE_SUPREME]
while end < len(courts) and courts[end].jurisdiction not in COLRs:
end += 1

# Create the column and add it to result
groups.append(courts[start:end])
start = end

return groups
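A quick sanity check of the boundary logic, using stand-in court objects (the non-COLR jurisdiction code here is a placeholder):

    class FakeCourt:
        def __init__(self, jurisdiction):
            self.jurisdiction = jurisdiction

    S, T = Court.STATE_SUPREME, Court.TERRITORY_SUPREME
    X = "not-a-colr"  # placeholder jurisdiction code
    courts = [FakeCourt(j) for j in (S, X, X, S, X, S, X, X, T, X)]

    groups = group_courts(courts, 3)
    # A raw split would be 4/3/3, but each boundary is pushed forward to
    # the next COLR, giving 5/3/2; every group after the first opens with
    # a supreme court, so a state's courts stay in one column.
    assert [len(g) for g in groups] == [5, 3, 2]
    assert all(g[0].jurisdiction in (S, T) for g in groups[1:])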
3 changes: 3 additions & 0 deletions cl/lib/command_utils.py
@@ -17,6 +17,9 @@ def handle(self, *args, **options):
logger.setLevel(logging.INFO)
elif verbosity > 1:
logger.setLevel(logging.DEBUG)
# This will make juriscraper's logger accept most logger calls.
juriscraper_logger = logging.getLogger("juriscraper")
juriscraper_logger.setLevel(logging.DEBUG)

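In practice this means running a scraper command with --verbosity 2 (or -v 2) now surfaces juriscraper's DEBUG output alongside the command's own, e.g.:

    import logging

    # After handle() runs with verbosity > 1, both loggers emit DEBUG:
    logging.getLogger("juriscraper").debug("now visible at -v 2")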

class CommandUtils:
…
25 changes: 6 additions & 19 deletions cl/lib/search_utils.py
@@ -233,8 +233,8 @@ def merge_form_with_courts(
}
bap_bundle = []
b_bundle = []
state_bundle: List = []
state_bundles = []
states = []
territories = []
for court in courts:
if court.jurisdiction == Court.FEDERAL_APPELLATE:
court_tabs["federal"].append(court)
@@ -247,15 +247,9 @@
else:
b_bundle.append(court)
elif court.jurisdiction in Court.STATE_JURISDICTIONS:
# State courts get bundled by supreme courts
if court.jurisdiction == Court.STATE_SUPREME:
# Whenever we hit a state supreme court, we append the
# previous bundle and start a new one.
if state_bundle:
state_bundles.append(state_bundle)
state_bundle = [court]
else:
state_bundle.append(court)
states.append(court)
elif court.jurisdiction in Court.TERRITORY_JURISDICTIONS:
territories.append(court)
elif court.jurisdiction in [
Court.FEDERAL_SPECIAL,
Court.COMMITTEE,
@@ -265,18 +259,11 @@
]:
court_tabs["special"].append(court)

# append the final state bundle after the loop ends. Hack?
state_bundles.append(state_bundle)

# Put the bankruptcy bundles in the courts dict
if bap_bundle:
court_tabs["bankruptcy_panel"] = [bap_bundle]
court_tabs["bankruptcy"] = [b_bundle]

# Divide the state bundles into the correct partitions
court_tabs["state"].append(state_bundles[:17])
court_tabs["state"].append(state_bundles[17:34])
court_tabs["state"].append(state_bundles[34:])
court_tabs["state"] = [states, territories]

return court_tabs, court_count_human, court_count

…
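Net effect, sketched; the render-time pairing with the new template filter is an assumption based on the group_courts filter added above:

    # The "state" tab now carries two flat lists split by jurisdiction,
    # replacing the old fixed 17/17/rest partition of per-state bundles.
    court_tabs["state"] = [states, territories]
    # Presumed render-time pairing with the new filter (illustrative):
    state_columns = group_courts(states, 3)  # balanced, COLR-aligned columns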
(6 more changed files not shown)
