Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow training a model on multiple annotations #8071

Merged
merged 22 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b6aefe3
implement train-model for multiple annotations (WIP)
philippotto Sep 9, 2024
618cd38
further refactoring of train model form
philippotto Sep 11, 2024
8a8a0cc
make onFinish work for multi-annotation training
philippotto Sep 11, 2024
89d1c78
Merge branch 'master' of github.com:scalableminds/webknossos into mul…
philippotto Sep 11, 2024
7e4e3ea
clean up
philippotto Sep 11, 2024
a927a7e
fix annotationId and csv validation
philippotto Sep 11, 2024
9167b0e
fix training model from annotation view
philippotto Sep 11, 2024
79f859d
Merge branch 'master' into multi-anno-training
philippotto Sep 12, 2024
732a22f
clean up
philippotto Sep 12, 2024
b405d11
Merge branch 'master' of github.com:scalableminds/webknossos into mul…
Sep 23, 2024
6eccb7e
WIP: Apply feedback
Sep 23, 2024
4926361
make bbox checking work while not refetching tracings
Sep 24, 2024
77e0439
some cleanup
Sep 24, 2024
f3b4c16
apply missing pr feedback
Sep 24, 2024
f748d6b
Merge branch 'master' of github.com:scalableminds/webknossos into mul…
Sep 24, 2024
66ecfef
improve ai modal csv input parsing
Sep 24, 2024
e2546b9
Update frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx
MichaelBuessemeyer Sep 24, 2024
acc2b2d
apply feedback
Sep 24, 2024
463de3c
Merge branch 'multi-anno-training' of github.com:scalableminds/webkno…
Sep 24, 2024
75aa6b1
Update frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx
MichaelBuessemeyer Sep 25, 2024
6e378a3
Merge branch 'master' into multi-anno-training
MichaelBuessemeyer Sep 25, 2024
1045a03
Merge branch 'master' into multi-anno-training
MichaelBuessemeyer Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/javascripts/admin/api/jobs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ export function startAlignSectionsJob(

type AiModelCategory = "em_neurons" | "em_nuclei";

type AiModelTrainingAnnotationSpecification = {
export type AiModelTrainingAnnotationSpecification = {
annotationId: string;
colorLayerName: string;
segmentationLayerName: string;
Expand Down
93 changes: 86 additions & 7 deletions frontend/javascripts/admin/voxelytics/ai_model_list_view.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import _ from "lodash";
import React, { useState } from "react";
import { SyncOutlined } from "@ant-design/icons";
import { Table, Button, Modal } from "antd";
import { getAiModels } from "admin/admin_rest_api";
import type { AiModel } from "types/api_flow_types";
import { PlusOutlined, SyncOutlined } from "@ant-design/icons";
import { Table, Button, Modal, Space } from "antd";
import { getAiModels, getTracingForAnnotationType } from "admin/admin_rest_api";
import type { AiModel, APIAnnotation, ServerVolumeTracing } from "types/api_flow_types";
import FormattedDate from "components/formatted_date";
import { formatUserName } from "oxalis/model/accessors/user_accessor";
import { useSelector } from "react-redux";
Expand All @@ -12,11 +12,19 @@ import { JobState } from "admin/job/job_list_view";
import { Link } from "react-router-dom";
import { useGuardedFetch } from "libs/react_helpers";
import { PageNotAvailableToNormalUser } from "components/permission_enforcer";
import { type AnnotationWithDataset, TrainAiModelTab } from "oxalis/view/jobs/train_ai_model";
import {
getResolutionInfo,
getSegmentationLayerByName,
} from "oxalis/model/accessors/dataset_accessor";
import { serverVolumeToClientVolumeTracing } from "oxalis/model/reducers/volumetracing_reducer";
import type { Vector3 } from "oxalis/constants";
import type { Key } from "react";

export default function AiModelListView() {
const activeUser = useSelector((state: OxalisState) => state.activeUser);
const [refreshCounter, setRefreshCounter] = useState(0);
const [isTrainModalVisible, setIsTrainModalVisible] = useState(false);
const [aiModels, isLoading] = useGuardedFetch(
getAiModels,
[],
Expand All @@ -30,10 +38,18 @@ export default function AiModelListView() {

return (
<div className="container voxelytics-view">
{isTrainModalVisible ? (
<TrainNewAiJobModal onClose={() => setIsTrainModalVisible(false)} />
) : null}
<div className="pull-right">
<Button onClick={() => setRefreshCounter((val) => val + 1)}>
<SyncOutlined spin={isLoading} /> Refresh
</Button>
<Space>
<Button onClick={() => setIsTrainModalVisible(true)}>
<PlusOutlined /> Train new Model
</Button>
<Button onClick={() => setRefreshCounter((val) => val + 1)}>
<SyncOutlined spin={isLoading} /> Refresh
</Button>
</Space>
</div>
<h3>AI Models</h3>
<Table
Expand Down Expand Up @@ -92,6 +108,69 @@ export default function AiModelListView() {
);
}

function TrainNewAiJobModal({ onClose }: { onClose: () => void }) {
const [annotationsWithDatasets, setAnnotationsWithDatasets] = useState<
AnnotationWithDataset<APIAnnotation>[]
>([]);

const getMagForSegmentationLayer = async (annotationId: string, layerName: string) => {
// The layer name is a human-readable one. It can either belong to an annotationLayer
// (threfore, also to a volume tracing) or to the actual dataset.
// Both are checked below. This won't be ambiguous because annotationLayers must not
// have names that dataset layers already have.

const annotationWithDataset = annotationsWithDatasets.find(({ annotation }) => {
const currentAnnotationId = annotation.id;
return annotationId === currentAnnotationId;
});
if (annotationWithDataset == null) {
throw new Error("Cannot find annotation for specified id.");
}

const { annotation, dataset } = annotationWithDataset;

let annotationLayer = annotation.annotationLayers.find((l) => l.name === layerName);

if (annotationLayer != null) {
const volumeTracing = (await getTracingForAnnotationType(
annotation,
annotationLayer,
)) as ServerVolumeTracing;
const resolutions: Vector3[] = volumeTracing.resolutions?.map(({ x, y, z }) => [x, y, z]) || [
[1, 1, 1],
];
return getResolutionInfo(resolutions).getFinestResolution();
} else {
const segmentationLayer = getSegmentationLayerByName(dataset, layerName);
daniel-wer marked this conversation as resolved.
Show resolved Hide resolved
return getResolutionInfo(segmentationLayer.resolutions).getFinestResolution();
}
};

return (
<Modal
width={875}
open
title={
<>
<i className="fas fa-magic icon-margin-right" />
AI Analysis
</>
}
onCancel={onClose}
footer={null}
>
<TrainAiModelTab
getMagForSegmentationLayer={getMagForSegmentationLayer}
onClose={onClose}
annotationsWithDatasets={annotationsWithDatasets}
onAddAnnotationsWithDatasets={(newItems) => {
setAnnotationsWithDatasets([...annotationsWithDatasets, ...newItems]);
}}
/>
</Modal>
);
}

const renderActionsForModel = (model: AiModel) => {
if (model.trainingJob == null) {
return;
Expand Down
30 changes: 16 additions & 14 deletions frontend/javascripts/components/layer_selection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,28 @@ import { Form, Select } from "antd";
import { getReadableNameOfVolumeLayer } from "oxalis/model/accessors/volumetracing_accessor";
import type { HybridTracing } from "oxalis/store";
import type React from "react";
import type { APIDataLayer } from "types/api_flow_types";
import type { APIAnnotation, APIDataLayer } from "types/api_flow_types";

type LayerSelectionProps = {
type LayerSelectionProps<L extends { name: string }> = {
name: string | Array<string | number>;
chooseSegmentationLayer: boolean;
layers: APIDataLayer[];
tracing: HybridTracing;
layers: L[];
getReadableNameForLayer: (layer: L) => string;
fixedLayerName?: string;
label?: string;
};

export function LayerSelection({
export function LayerSelection<L extends { name: string }>({
layers,
tracing,
getReadableNameForLayer,
fixedLayerName,
layerType,
onChange,
style,
value,
}: {
layers: APIDataLayer[];
tracing: HybridTracing;
layers: L[];
getReadableNameForLayer: (layer: L) => string;
fixedLayerName?: string;
layerType?: string;
style?: React.CSSProperties;
Expand All @@ -49,7 +50,7 @@ export function LayerSelection({
value={value}
>
{layers.map((layer) => {
const readableName = getReadableNameOfVolumeLayer(layer, tracing) || layer.name;
const readableName = getReadableNameForLayer(layer);
return (
<Select.Option key={layer.name} value={layer.name}>
{readableName}
Expand All @@ -60,18 +61,19 @@ export function LayerSelection({
);
}

export function LayerSelectionFormItem({
export function LayerSelectionFormItem<L extends { name: string }>({
name,
chooseSegmentationLayer,
layers,
tracing,
getReadableNameForLayer,
fixedLayerName,
label,
}: LayerSelectionProps): JSX.Element {
}: LayerSelectionProps<L>): JSX.Element {
const layerType = chooseSegmentationLayer ? "segmentation" : "color";
return (
<Form.Item
label={label || "Layer"}
name="layerName"
name={name}
rules={[
{
required: true,
Expand All @@ -85,7 +87,7 @@ export function LayerSelectionFormItem({
layers={layers}
fixedLayerName={fixedLayerName}
layerType={layerType}
tracing={tracing}
getReadableNameForLayer={getReadableNameForLayer}
/>
</Form.Item>
);
Expand Down
2 changes: 1 addition & 1 deletion frontend/javascripts/oxalis/default_state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ const defaultState: OxalisState = {
showVersionRestore: false,
showDownloadModal: false,
showPythonClientModal: false,
aIJobModalState: "invisible",
aIJobModalState: "neuron_inferral",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
aIJobModalState: "neuron_inferral",
aIJobModalState: "invisible",

showRenderAnimationModal: false,
showShareModal: false,
storedLayouts: {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,32 @@ export function getReadableNameByVolumeTracingId(
tracingId: string,
) {
const volumeDescriptor = getVolumeDescriptorById(annotation, tracingId);
return volumeDescriptor.name || "Volume";
return volumeDescriptor.name;
}

export function getSegmentationLayerByHumanReadableName(
dataset: APIDataset,
annotation: APIAnnotation | HybridTracing,
name: string,
) {
try {
const layer = getSegmentationLayerByName(dataset, name);
return layer;
} catch {}

const layer = getVolumeTracingLayers(dataset).find((currentLayer) => {
if (currentLayer.tracingId == null) {
throw new Error("getVolumeTracingLayers must return tracing.");
}
const readableName = getReadableNameByVolumeTracingId(annotation, currentLayer.tracingId);
return readableName === name;
});

if (layer == null) {
throw new Error("Could not find segmentation layer with the name: " + name);
}

return layer;
}

export function getAllReadableLayerNames(dataset: APIDataset, tracing: Tracing) {
Expand All @@ -135,7 +160,7 @@ export function getAllReadableLayerNames(dataset: APIDataset, tracing: Tracing)

export function getReadableNameForLayerName(
dataset: APIDataset,
tracing: Tracing,
tracing: APIAnnotation | HybridTracing,
layerName: string,
): string {
const layer = getLayerByName(dataset, layerName, true);
Expand Down Expand Up @@ -842,7 +867,7 @@ export const getBucketRetrievalSourceFn =

export function getReadableNameOfVolumeLayer(
layer: APIDataLayer,
tracing: HybridTracing,
tracing: APIAnnotation | HybridTracing,
): string | null {
return "tracingId" in layer && layer.tracingId != null
? getReadableNameByVolumeTracingId(tracing, layer.tracingId)
Expand Down
4 changes: 2 additions & 2 deletions frontend/javascripts/oxalis/model_initialization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ function initializeDataset(
const volumeTracings = getServerVolumeTracings(serverTracings);

if (volumeTracings.length > 0) {
const newDataLayers = setupLayerForVolumeTracing(dataset, volumeTracings);
const newDataLayers = getMergedDataLayersFromDatasetAndVolumeTracings(dataset, volumeTracings);
mutableDataset.dataSource.dataLayers = newDataLayers;
validateVolumeLayers(volumeTracings, newDataLayers);
}
Expand Down Expand Up @@ -480,7 +480,7 @@ function initializeDataLayerInstances(gpuFactor: number | null | undefined): {
};
}

function setupLayerForVolumeTracing(
function getMergedDataLayersFromDatasetAndVolumeTracings(
dataset: APIDataset,
tracings: Array<ServerVolumeTracing>,
): Array<APIDataLayer> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ export default function CreateAnimationModalWrapper(props: Props) {
function CreateAnimationModal(props: Props) {
const { isOpen, onClose } = props;
const dataset = useSelector((state: OxalisState) => state.dataset);
const tracing = useSelector((state: OxalisState) => state.tracing);
const activeOrganization = useSelector((state: OxalisState) => state.activeOrganization);

const colorLayers = getColorLayers(dataset);
Expand Down Expand Up @@ -403,7 +402,7 @@ function CreateAnimationModal(props: Props) {
layers={colorLayers}
value={selectedColorLayerName}
onChange={setSelectedColorLayerName}
tracing={tracing}
getReadableNameForLayer={(layer) => layer.name}
style={{ width: "100%" }}
/>
</Col>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,9 @@ function _DownloadModalView({
layers={layers}
value={selectedLayerName}
onChange={setSelectedLayerName}
tracing={tracing}
getReadableNameForLayer={(layer) =>
getReadableNameOfVolumeLayer(layer, tracing) || layer.name
}
style={{ width: "100%" }}
/>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ import { isBoundingBoxExportable } from "./download_modal_view";
import features from "features";
import { setAIJobModalStateAction } from "oxalis/model/actions/ui_actions";
import { InfoCircleOutlined } from "@ant-design/icons";
import { TrainAiModelTab, CollapsibleWorkflowYamlEditor } from "../jobs/train_ai_model";
import {
CollapsibleWorkflowYamlEditor,
TrainAiModelFromAnnotationTab,
} from "../jobs/train_ai_model";
import { LayerSelectionFormItem } from "components/layer_selection";
import { useGuardedFetch } from "libs/react_helpers";
import _ from "lodash";
Expand Down Expand Up @@ -291,7 +294,7 @@ export function StartAIJobModal({ aIJobModalState }: StartAIJobModalProps) {
? {
label: "Train a model",
key: "trainModel",
children: <TrainAiModelTab onClose={onClose} />,
children: <TrainAiModelFromAnnotationTab onClose={onClose} />,
}
: null,
isSuperUser
Expand Down Expand Up @@ -627,11 +630,14 @@ function StartJobForm(props: StartJobFormProps) {
initialName={`${dataset.name}_${props.suggestedDatasetSuffix}`}
/>
<LayerSelectionFormItem
name="layerName"
chooseSegmentationLayer={chooseSegmentationLayer}
label={chooseSegmentationLayer ? "Segmentation Layer" : "Image data layer"}
layers={layers}
fixedLayerName={fixedSelectedLayer?.name}
tracing={tracing}
getReadableNameForLayer={(layer) =>
getReadableNameOfVolumeLayer(layer, tracing) || layer.name
}
/>
<BoundingBoxSelectionFormItem
isBoundingBoxConfigurable={isBoundingBoxConfigurable}
Expand Down
Loading