Skip to content

Commit

Permalink
Merge pull request #16 from PyCampES/heuristic
Browse files Browse the repository at this point in the history
Heuristic
  • Loading branch information
gilgamezh authored Apr 1, 2024
2 parents fd3d660 + fd2e364 commit 5bb47c6
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 124 deletions.
67 changes: 32 additions & 35 deletions src/ficamp/__main__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import argparse
import json
import os
import shutil
from enum import StrEnum
from collections import defaultdict

import questionary
from dotenv import load_dotenv
from sqlmodel import Session, SQLModel, create_engine, select

from ficamp.classifier.infer import infer_tx_category
from ficamp.classifier.keywords import sort_by_keyword_matches
from ficamp.classifier.preprocessing import preprocess
from ficamp.datastructures import Tx
from ficamp.parsers.abn import AbnParser
from ficamp.parsers.bbva import AccountBBVAParser, CreditCardBBVAParser
from ficamp.parsers.bsabadell import AccountBSabadellParser, CreditCardBSabadellParser
from ficamp.parsers.caixabank import CaixaBankParser
from ficamp.parsers.enums import BankParser


Expand All @@ -37,14 +30,13 @@ def cli() -> argparse.Namespace:
default="abn",
help="Specify the bank for the import",
)
import_parser.add_argument("filename", help="File to load")
import_parser.add_argument("--filename", help="File to load")
import_parser.set_defaults(func=import_data)

# Subparser for the categorize command
categorize_parser = subparsers.add_parser(
"categorize", help="Categorize transactions"
)
categorize_parser.add_argument("--infer-category", action="store_true")
categorize_parser.set_defaults(func=categorize)

args = parser.parse_args()
Expand Down Expand Up @@ -80,25 +72,34 @@ class DefaultAnswers:
NEW = "Type a new category"


def query_business_category(tx, session, infer_category=False):
# first try to get from the category_dict
def make_map_cat_to_kws(session):
statement = select(Tx).where(Tx.category.is_not(None))
known_cat_tx = session.exec(statement).all()
keywords = defaultdict(list)
for tx in known_cat_tx:
keywords[tx.category].extend(tx.concept_clean.split())
return keywords


def query_business_category(tx, session):
# Clean up the transaction concept string
tx.concept_clean = preprocess(tx.concept)

# If there is an exact match to the known transactions, return that one
statement = select(Tx.category).where(Tx.concept_clean == tx.concept_clean)
category = session.exec(statement).first()
if category:
return category
# ask the user if we don't know it
# query each time to update
statement = select(Tx.category).where(Tx.category.is_not(None)).distinct()
categories_choices = session.exec(statement).all()

# Build map of category --> keywords
cats = make_map_cat_to_kws(session)
cats_sorted_by_matches = sort_by_keyword_matches(cats, tx.concept_clean)
# Show categories to user sorted by keyword criterion
categories_choices = [cat for _, cat in cats_sorted_by_matches]
categories_choices.extend([DefaultAnswers.NEW, DefaultAnswers.SKIP])
default_choice = DefaultAnswers.SKIP
if infer_category:
inferred_category = infer_tx_category(tx)
if inferred_category:
categories_choices.append(inferred_category)
default_choice = inferred_category
print(f"{tx.date.isoformat()} {tx.amount} {tx.concept_clean}")
default_choice = categories_choices[0]

print(f"{tx.date.isoformat()} | {tx.amount} | {tx.concept_clean}")
answer = questionary.select(
"Please select the category for this TX",
choices=categories_choices,
Expand All @@ -115,17 +116,16 @@ def query_business_category(tx, session, infer_category=False):
return answer


def categorize(args, engine):
"""Function to categorize transactions."""
def categorize(engine):
"""Classify transactions into categories"""
try:
with Session(engine) as session:
statement = select(Tx).where(Tx.category.is_(None))
results = session.exec(statement).all()
print(f"Got {len(results)} Tx to categorize")
for tx in results:
print(f"Processing {tx}")
tx_category = query_business_category(
tx, session, infer_category=args.infer_category)
tx_category = query_business_category(tx, session)
if tx_category:
print(f"Saving category for {tx.concept}: {tx_category}")
tx.category = tx_category
Expand All @@ -135,19 +135,16 @@ def categorize(args, engine):
else:
print("Not saving any category for thi Tx")
except KeyboardInterrupt:
print("Closing")
print("Session interrupted. Closing.")


def main():
# create DB
engine = create_engine("sqlite:///ficamp.db")
# create tables
SQLModel.metadata.create_all(engine)

engine = create_engine("sqlite:///ficamp.db") # create DB
SQLModel.metadata.create_all(engine) # create tables
try:
args = cli()
if args.command:
args.func(args, engine)
args.func(engine)
except KeyboardInterrupt:
print("\nClosing")

Expand Down
Empty file removed src/ficamp/classifier/encoding.py
Empty file.
88 changes: 0 additions & 88 deletions src/ficamp/classifier/features.py

This file was deleted.

Empty file.
14 changes: 14 additions & 0 deletions src/ficamp/classifier/keywords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Logic to sort transactions based on keywords.
"""
import json
import pathlib


def sort_by_keyword_matches(categories: dict, description: str) -> list[str]:
description = description.lower()
matches = []
for category, keywords in categories.items():
n_matches = sum(keyword in description for keyword in keywords)
matches.append((n_matches, category))
return sorted(matches, reverse=True)
1 change: 0 additions & 1 deletion src/ficamp/classifier/payment_method.py

This file was deleted.

27 changes: 27 additions & 0 deletions src/ficamp/classifier/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import string


def remove_digits(s: str) -> str:
"""
Return string without words that have more that 2 digits.
Expand Down Expand Up @@ -29,6 +32,27 @@ def remove_comma(s: str) -> str:
return " ".join(s.split(","))


def remove_punctuation(s: str) -> str:
punctuation = set(string.punctuation)
out = "".join((" " if char in punctuation else char for char in s))
return " ".join(out.split()) # Remove double spaces


def remove_isolated_digits(s: str) -> str:
"""Remove words made only of digits"""
digits = set(string.digits)
clean = []
for word in s.split():
if not all((char in digits for char in word)):
clean.append(word)
return " ".join(clean)


def remove_short_words(s: str) -> str:
"""Remove words made only of digits"""
return " ".join((word for word in s.split() if len(word) >= 2))


def preprocess(s: str) -> str:
"Clean up transaction description"
steps = (
Expand All @@ -37,6 +61,9 @@ def preprocess(s: str) -> str:
remove_colon,
remove_comma,
remove_digits,
remove_punctuation,
remove_isolated_digits,
remove_short_words,
)
out = s
for func in steps:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
remove_comma,
remove_digits,
remove_pipes,
remove_punctuation,
remove_isolated_digits,
remove_short_words,
)


Expand Down Expand Up @@ -56,6 +59,41 @@ def test_remove_comma(inp, exp):
assert remove_comma(inp) == exp


@pytest.mark.parametrize(
("inp,exp"),
(
("hello world", "hello world"),
("hello/world", "hello world"),
("hello.world", "hello world"),
("hello.(.world))", "hello world"),
),
)
def test_remove_punctuation(inp, exp):
assert remove_punctuation(inp) == exp


@pytest.mark.parametrize(
("inp,exp"),
(
("hello22 world", "hello22 world"),
("hello 22 world", "hello world"),
),
)
def test_remove_isolated_digits(inp, exp):
assert remove_isolated_digits(inp) == exp


@pytest.mark.parametrize(
("inp,exp"),
(
("hello a world", "hello world"),
("hello aa world", "hello aa world"),
),
)
def test_remove_short_words(inp, exp):
assert remove_short_words(inp) == exp


@pytest.mark.parametrize(
("inp,exp"),
(
Expand All @@ -70,6 +108,9 @@ def test_remove_comma(inp, exp):
("SEPA 1231|AMSTERDAM 123BIC", "sepa amsterdam"),
("CSID:NL0213324324324", "csid"),
("CSID:NL0213324324324 HELLO,world1332", "csid hello"),
("CSID:NL021332432 N26 HELLO,world1332", "csid n26 hello"),
("CSID:NL021332432 4324 HELLO,world1332", "csid hello"),
("CSID:NL021332432 n. HELLO,world1332", "csid hello"),
),
)
def test_preprocess(inp, exp):
Expand Down

0 comments on commit 5bb47c6

Please sign in to comment.