Skip to content

Commit

Permalink
display is buggy, but this is general idea
Browse files Browse the repository at this point in the history
  • Loading branch information
metazool committed Aug 7, 2024
1 parent e7d3030 commit 9cb2068
Showing 1 changed file with 36 additions and 11 deletions.
47 changes: 36 additions & 11 deletions cyto_ml/visualisation/pages/02_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
cached_image,
)

DEPTH = 8


@st.cache_resource
def kmeans_cluster() -> KMeans:
"""
K-means cluster the embeddings, option for default size
K-means cluster the embeddings, option in session for default size
"""
print("model")
X = image_embeddings("plankton")
n_clusters = st.session_state["n_clusters"]
# Initialize and fit KMeans
Expand All @@ -25,7 +26,7 @@ def kmeans_cluster() -> KMeans:
@st.cache_data
def image_labels() -> dict:
"""
TODO good form to move all this into cyto_ml, call from there
TODO good form to move all this into cyto_ml, call from there?
"""
km = kmeans_cluster()
clusters = dict(zip(set(km.labels_), [[] for _ in range(len(set(km.labels_)))]))
Expand All @@ -36,36 +37,60 @@ def image_labels() -> dict:
return clusters


def show_cluster():
def add_more() -> None:
st.session_state["depth"] += DEPTH


def do_less() -> None:
st.session_state["depth"] -= DEPTH


def show_cluster() -> None:

# TODO n_clusters configurable with selector
fitted = image_labels()
closest = fitted[st.session_state["cluster"]]

# seems backwards, something in session state?
rows = []
for _ in range(0, 8):
rows.append(st.columns(8))
for index, _ in enumerate(rows):
for c in rows[index]:
# TODO figure out why this renders twice
for _ in range(0, st.session_state["depth"]):
cols = st.columns(DEPTH)
for c in cols:
c.image(cached_image(closest.pop()), width=60)


# TODO some visualisation, actual content, etc
def main() -> None:

# start with this cluster label
if "cluster" not in st.session_state:
st.session_state["cluster"] = 1

# start kmeans with this many target clusters
if "n_clusters" not in st.session_state:
st.session_state["n_clusters"] = 5

# show this many images * 8 across
if "depth" not in st.session_state:
st.session_state["depth"] = 8

st.selectbox(
"cluster",
"cluster label",
[x for x in range(0, st.session_state["n_clusters"])],
key="cluster",
on_change=show_cluster,
)

st.selectbox(
"n_clusters",
[3, 5, 8],
key="n_clusters",
on_change=kmeans_cluster,
)

st.button("more", on_click=add_more)

st.button("less", on_click=do_less)

show_cluster()


Expand Down

0 comments on commit 9cb2068

Please sign in to comment.