Skip to content

Commit

Permalink
creates graph on offline data with transformed values in vector
Browse files Browse the repository at this point in the history
  • Loading branch information
NishanthJKumar committed Jul 24, 2023
1 parent 6502fc4 commit a724b23
Showing 1 changed file with 66 additions and 5 deletions.
71 changes: 66 additions & 5 deletions scripts/spot_cube_place_active_sampler_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ def _vec_to_xy(vec: Array) -> Tuple[float, float]:
return (x, y)


def _vec_to_transformed_vec(vec: Array) -> Tuple[float, float]:
place_robot_xy = math_helpers.Vec2(*vec[-3:-1])

world_fiducial = math_helpers.Vec2(
vec[12], # state.get(surface, "x"),
vec[13], # state.get(surface, "y"),
)
world_to_robot = math_helpers.SE2Pose(
vec[3], # state.get(robot, "x"),
vec[4], # state.get(robot, "y"),
vec[6], # state.get(robot, "yaw"))
)
fiducial_in_robot_frame = world_to_robot.inverse() * world_fiducial
x, y = place_robot_xy - fiducial_in_robot_frame
new_vec = np.copy(vec)
new_vec[-3] = x
new_vec[-2] = y
return new_vec


def _create_image(X: List[Array],
y: List[Array],
classifier: Optional[BinaryClassifier] = None) -> Image:
Expand Down Expand Up @@ -186,7 +206,7 @@ def fit(self, X: Array, y: Array) -> None:
def classify(self, x: Array) -> bool:
# Approximate.
_, y_pt = _vec_to_xy(x)
return y_pt > 0
return y_pt > -0.03 and y_pt < 0.13

def predict_proba(self, x: Array) -> float:
return 1.0 if self.classify(x) else 0.0
Expand Down Expand Up @@ -223,7 +243,35 @@ def _run_sample_efficiency_analysis(X: List[Array], y: List[Array]) -> None:
# "always-true": lambda: _ConstantModel(CFG.seed, True),
"always-false":
lambda: _ConstantModel(CFG.seed, False),
"mlp":
"mlp_unreduced":
lambda: MLPBinaryClassifier(
seed=CFG.seed,
balance_data=CFG.mlp_classifier_balance_data,
max_train_iters=CFG.sampler_mlp_classifier_max_itr,
learning_rate=CFG.learning_rate,
weight_decay=CFG.weight_decay,
use_torch_gpu=CFG.use_torch_gpu,
train_print_every=CFG.pytorch_train_print_every,
n_iter_no_change=CFG.mlp_classifier_n_iter_no_change,
hid_sizes=CFG.mlp_classifier_hid_sizes,
n_reinitialize_tries=CFG.
sampler_mlp_classifier_n_reinitialize_tries,
weight_init="default"),
"mlp_transformed":
lambda: MLPBinaryClassifier(
seed=CFG.seed,
balance_data=CFG.mlp_classifier_balance_data,
max_train_iters=CFG.sampler_mlp_classifier_max_itr,
learning_rate=CFG.learning_rate,
weight_decay=CFG.weight_decay,
use_torch_gpu=CFG.use_torch_gpu,
train_print_every=CFG.pytorch_train_print_every,
n_iter_no_change=CFG.mlp_classifier_n_iter_no_change,
hid_sizes=CFG.mlp_classifier_hid_sizes,
n_reinitialize_tries=CFG.
sampler_mlp_classifier_n_reinitialize_tries,
weight_init="default"),
"mlp_reduced":
lambda: MLPBinaryClassifier(
seed=CFG.seed,
balance_data=CFG.mlp_classifier_balance_data,
Expand Down Expand Up @@ -253,14 +301,27 @@ def _run_sample_efficiency_analysis(X: List[Array], y: List[Array]) -> None:
train_idxs = idxs[num_valid:num_valid + num_training_data]
valid_idxs = idxs[:num_valid]
X_train = np.array([X[i] for i in train_idxs])
X_train_reduced = np.array([_vec_to_xy(X[i]) for i in train_idxs])
X_train_transformed = np.array([_vec_to_transformed_vec(X[i]) for i in train_idxs])
y_train = np.array([y[i] for i in train_idxs])
X_valid = [X[i] for i in valid_idxs]
X_valid_reduced = [np.array(_vec_to_xy(X[i])) for i in valid_idxs]
X_valid_transformed = [_vec_to_transformed_vec(X[i]) for i in valid_idxs]
y_valid = [y[i] for i in valid_idxs]
# Train.
model = create_model()
model.fit(X_train, y_train)
# Predict.
y_pred = [model.classify(x) for x in X_valid]
if model_name == "mlp_reduced":
model.fit(X_train_reduced, y_train)
# Predict.
y_pred = [model.classify(x) for x in X_valid_reduced]
elif model_name == "mlp_transformed":
model.fit(X_train_transformed, y_train)
# Predict
y_pred = [model.classify(x) for x in X_valid_transformed]
else:
model.fit(X_train, y_train)
# Predict.
y_pred = [model.classify(x) for x in X_valid]
acc = np.mean([(y == y_hat)
for y, y_hat in zip(y_valid, y_pred)])
print(f"Trial {i} accuracy: {acc}")
Expand Down

0 comments on commit a724b23

Please sign in to comment.