Skip to content

Commit

Permalink
Black notebooks + bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbalderrama committed Oct 5, 2021
1 parent f358d2b commit 5971f58
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 178 deletions.
97 changes: 47 additions & 50 deletions notebooks/a2r2-01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
"from torch import Tensor\n",
"from torch.nn import LSTM, Linear, Module, MSELoss\n",
"from torch.optim import Adam\n",
"from torch.utils.data import DataLoader, TensorDataset\n"
"from torch.utils.data import DataLoader, TensorDataset"
],
"outputs": [],
"metadata": {
Expand Down Expand Up @@ -236,7 +236,7 @@
"# classes dataset\n",
"classes_filename = \"classes_filiere.parquet\"\n",
"classes_path = BASE_DIRECTORY.joinpath(classes_filename)\n",
"classes_dataset = load_data(classes_path)\n"
"classes_dataset = load_data(classes_path)"
],
"outputs": [],
"metadata": {
Expand All @@ -260,23 +260,23 @@
"source": [
"# show a dataframe as a table\n",
"def display_dataframe(\n",
" dataframe: DataFrame,\n",
") -> None: \n",
" dataframe: DataFrame,\n",
") -> None:\n",
" if COLAB_ON:\n",
" spec = importlib.util.find_spec(\"google.colab\")\n",
" if spec: \n",
" data_table = importlib.import_module(\"google.colab.data_table\") \n",
" if spec:\n",
" data_table = importlib.import_module(\"google.colab.data_table\")\n",
" enable_dataframe_formatter = getattr(\n",
" data_table, \n",
" data_table,\n",
" \"enable_dataframe_formatter\",\n",
" ) \n",
" \n",
" enable_dataframe_formatter() \n",
" \n",
" #style = dataframe.style.set_caption(\"Dataframe Caption\")\n",
" #display.display(style)\n",
" )\n",
"\n",
" enable_dataframe_formatter()\n",
"\n",
" display.display(dataframe[:20000] if COLAB_ON else dataframe) \n"
" # style = dataframe.style.set_caption(\"Dataframe Caption\")\n",
" # display.display(style)\n",
"\n",
" display.display(dataframe[:20000] if COLAB_ON else dataframe)"
],
"outputs": [],
"metadata": {}
Expand Down Expand Up @@ -331,7 +331,7 @@
"cell_type": "code",
"execution_count": null,
"source": [
"display_dataframe(classes_dataset)\n"
"display_dataframe(classes_dataset)"
],
"outputs": [],
"metadata": {
Expand Down Expand Up @@ -418,8 +418,8 @@
"\n",
"def pre_process_by_aggregation_classes(\n",
" dataframe: DataFrame,\n",
") -> DataFrame: \n",
" return dataframe.groupby(\"fin_cours\").sum()\n"
") -> DataFrame:\n",
" return dataframe.groupby(\"fin_cours\").sum()"
],
"outputs": [],
"metadata": {
Expand Down Expand Up @@ -603,7 +603,7 @@
"cell_type": "code",
"execution_count": null,
"source": [
"# compute a next monday after a given number of weeks for the \n",
"# compute a next monday after a given number of weeks for the\n",
"# initial value (min) of the datetime index\n",
"def get_timestamp_bound(\n",
" dataframe: DataFrame,\n",
Expand Down Expand Up @@ -632,12 +632,12 @@
" mode=\"lines\",\n",
" name=column,\n",
" )\n",
" \n",
"\n",
" figure.add_trace(\n",
" scatter,\n",
" secondary_y=secondary_y,\n",
" )\n",
" \n",
"\n",
" for delimiter in delimiters:\n",
" figure.add_shape(\n",
" type=\"line\",\n",
Expand All @@ -647,9 +647,9 @@
" y1=0,\n",
" line=dict(\n",
" # color=\"Gray\",\n",
" width=1, \n",
" width=1,\n",
" dash=\"dashdot\",\n",
" ), \n",
" ),\n",
" )\n",
"\n",
" figure.add_shape(\n",
Expand All @@ -674,13 +674,13 @@
" textangle=90,\n",
" xshift=10,\n",
" )\n",
" \n",
"\n",
" figure.add_annotation(\n",
" x=delimiters[0],\n",
" y=dmax,\n",
" text=\"validation\",\n",
" showarrow=True,\n",
" yshift=-15,\n",
" yshift=-15,\n",
" )\n",
"\n",
" figure.add_annotation(\n",
Expand All @@ -694,24 +694,18 @@
" figure.update_yaxes(\n",
" rangemode=\"tozero\",\n",
" # type=\"log\",\n",
" )\n",
" )\n",
"\n",
" figure.update_xaxes(range=[dataframe.index.min(), dataframe.index.max()])\n",
" figure.update_yaxes(title_text=columns[0], secondary_y=False)\n",
" figure.update_yaxes(title_text=columns[1], secondary_y=True) \n",
" figure.update_yaxes(title_text=columns[1], secondary_y=True)\n",
" figure.update_layout(\n",
" title_text=\"Count of Buses & Classes\",\n",
" template=\"simple_white\", \n",
" legend=dict(\n",
" orientation=\"h\",\n",
" yanchor=\"bottom\",\n",
" y=1.02,\n",
" xanchor=\"right\",\n",
" x=1\n",
" ) \n",
" template=\"simple_white\",\n",
" legend=dict(orientation=\"h\", yanchor=\"bottom\", y=1.02, xanchor=\"right\", x=1),\n",
" )\n",
"\n",
" figure.show() "
" figure.show()"
],
"outputs": [],
"metadata": {}
Expand Down Expand Up @@ -837,7 +831,7 @@
" la_toussaint,\n",
" la_toussaint + one_week_timedelta,\n",
" )\n",
" \n",
"\n",
" # dataframe.drop([\"nombre_etudiant\"], axis=1, inplace=True)\n",
" return dataframe_"
],
Expand Down Expand Up @@ -980,7 +974,7 @@
" # Convert the final state to our desired output shape (batch_size, output_dim)\n",
" out = self.fc(out)\n",
"\n",
" return out\n"
" return out"
],
"outputs": [],
"metadata": {
Expand All @@ -1000,7 +994,7 @@
"cell_type": "code",
"execution_count": null,
"source": [
"# dimension (neurons) of a hidden layer \n",
"# dimension (neurons) of a hidden layer\n",
"HIDDEN_DIM = 64\n",
"\n",
"# number of hidden layers\n",
Expand All @@ -1009,7 +1003,7 @@
"# number of rows processed at the same time\n",
"BATCH_SIZE = 64\n",
"\n",
"# number of iterations during training\n",
"# number of iterations during training\n",
"EPOCHS = 100"
],
"outputs": [],
Expand Down Expand Up @@ -1046,7 +1040,7 @@
"cell_type": "code",
"execution_count": null,
"source": [
"# Helper to train the NN model\n",
"# Helper to train the NN model\n",
"class RunnerHelper:\n",
" def __init__(self, model, loss_fn, optimizer):\n",
" self.model = model\n",
Expand Down Expand Up @@ -1262,7 +1256,7 @@
"train_loader, val_loader, test_loader = to_dataloaders(\n",
" (X_train, y_train),\n",
" (X_val, y_val),\n",
" (X_test, y_test), \n",
" (X_test, y_test),\n",
" scaler,\n",
" BATCH_SIZE,\n",
")\n",
Expand Down Expand Up @@ -1564,23 +1558,24 @@
" *,\n",
" backgrounds: Optional[Union[str, Sequence[str]]],\n",
" minutes: int,\n",
") -> DataFrame: \n",
") -> DataFrame:\n",
" if not backgrounds:\n",
" return shift_time_all(dataframe, minutes=minutes)\n",
" \n",
"\n",
" dataframe_ = dataframe.copy()\n",
" backgrounds_ = [backgrounds] if isinstance(backgrounds, str) else backgrounds \n",
" backgrounds_ = [backgrounds] if isinstance(backgrounds, str) else backgrounds\n",
" delta = Timedelta(minutes, unit=\"T\")\n",
" dataframe_.reset_index(inplace=True)\n",
" for background in backgrounds_:\n",
" dataframe_.loc[dataframe_[\"filiere\"] == background, \"fin_cours\"] = (\n",
" dataframe_[\"fin_cours\"] + delta\n",
" )\n",
" \n",
"\n",
" dataframe_.set_index(dataframe_.columns[0], inplace=True)\n",
" display_dataframe(dataframe_)\n",
" return dataframe_\n",
"\n",
"\n",
"def plot_prediction_interval_with_staggings(\n",
" dataframe: DataFrame,\n",
" staggered: DataFrame,\n",
Expand Down Expand Up @@ -1661,7 +1656,9 @@
"\n",
" figure.add_trace(bar_plot, row=4, col=1)\n",
" figure.update_xaxes(showticklabels=True, row=1, col=1)\n",
" figure.update_yaxes(title_text=\"difference\", row=4, col=1, zeroline=True, zerolinecolor=\"gray\")\n",
" figure.update_yaxes(\n",
" title_text=\"difference\", row=4, col=1, zeroline=True, zerolinecolor=\"gray\"\n",
" )\n",
" figure.update_xaxes(\n",
" showticklabels=False,\n",
" visible=False,\n",
Expand All @@ -1688,13 +1685,13 @@
" backgrounds: Optional[Union[str, Sequence[str]]],\n",
" minutes: int,\n",
") -> DataFrame:\n",
" #staggered_classes = shift_time_all(classes, minutes=minutes)\n",
" # staggered_classes = shift_time_all(classes, minutes=minutes)\n",
" classes_dataset_ = shift_time(\n",
" classes,\n",
" backgrounds=backgrounds,\n",
" minutes=minutes,\n",
" )\n",
" \n",
"\n",
" staggered_classes = pre_process_by_aggregation_classes(classes_dataset_)\n",
" dataframe = merge_datasets(staggered_classes, buses)\n",
" dataframe = add_features(dataframe, holidays=True)\n",
Expand Down Expand Up @@ -1736,11 +1733,11 @@
"\n",
"SHIFT_IN_MINUTES = 45\n",
"\n",
"# available backgrounds (None mean do not filter an take 'all' of them): \n",
"# available backgrounds (None mean do not filter an take 'all' of them):\n",
"# ['ISTIC', 'DUT', 'ESIR', 'SVE', 'SPM', 'Math', 'Philo']\n",
"BACKGROUNDS = None\n",
"\n",
"# END : play \n",
"# END : play\n",
"####################\n",
"\n",
"\n",
Expand Down
Loading

0 comments on commit 5971f58

Please sign in to comment.