Skip to content

Commit

Permalink
add nsteps and maxlags to arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
annacprice committed Sep 6, 2021
1 parent e1acb34 commit 42b2255
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
10 changes: 5 additions & 5 deletions covate/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import matplotlib.pyplot as plt
from .utils import appendline, pairwise, getdate, getenddate

def buildmodel(timeseries, lineagelist, regionlist, enddate, output, validate):
def buildmodel(timeseries, lineagelist, regionlist, enddate, output, maxlags, nsteps, validate):
""" Run stats tests for each lineage and select model and parameters"""

maxlag=14
nsteps=14
maxlag=int(maxlags)
nsteps=int(nsteps)
alpha = 0.05

for lineage in lineagelist:
Expand Down Expand Up @@ -273,7 +273,7 @@ def vecerrcorr(X_train, lineage, VECMdeterm, lag, coint_count, regionlist, nstep
pred = (pd.DataFrame(forecast.round(0), columns=X_train.columns, index=idx))

# cast negative predictions to zero
#pred[pred<0] = 0
pred[pred<0] = 0

path = os.path.join(output, str(getenddate(enddate)), lineage, 'prediction')

Expand Down Expand Up @@ -316,7 +316,7 @@ def vecerrcorrvalid(X_train, X_test, lineage, VECMdeterm, lag, coint_count, regi
pred = (pd.DataFrame(forecast.round(0), index=X_test.index, columns=X_test.columns))

# cast negative predictions to 0
#pred[pred<0] = 0
pred[pred<0] = 0

path = os.path.join(output, str(getenddate(enddate)), lineage, 'validation')

Expand Down
9 changes: 5 additions & 4 deletions covate/build_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dateutil.relativedelta import relativedelta
from .utils import getdate, getenddate, createoutputdir

def buildseries(metadata, regions, adm, lineagetype, timeperiod, enddate, output, validate):
def buildseries(metadata, regions, adm, lineagetype, timeperiod, enddate, output, nsteps, validate):
""" Build the time series for lineages common to specified regions"""

# load metadata and index by date
Expand All @@ -17,7 +17,7 @@ def buildseries(metadata, regions, adm, lineagetype, timeperiod, enddate, output
df[adm] = df[adm].astype(str)

# select time period
df, enddate = gettimeperiod(df, timeperiod, enddate, validate)
df, enddate = gettimeperiod(df, timeperiod, enddate, nsteps, validate)

# get region list
region_list = [str(region) for region in regions.split(', ')]
Expand Down Expand Up @@ -71,7 +71,7 @@ def buildseries(metadata, regions, adm, lineagetype, timeperiod, enddate, output
return countbydate, lineagecommon, region_list, enddate


def gettimeperiod(dataframe, timeperiod, enddate, validate):
def gettimeperiod(dataframe, timeperiod, enddate, nsteps, validate):
"""Extract time period from metadata specified by --time-period"""

# if enddate is not specified, use the most recent date in the metadata minus 7 days
Expand All @@ -84,7 +84,8 @@ def gettimeperiod(dataframe, timeperiod, enddate, validate):
if not validate:
startdate = enddate - relativedelta(weeks=+int(timeperiod))
else:
startdate = enddate - relativedelta(weeks=+int(timeperiod)+2)
startdate = enddate - relativedelta(weeks=+int(timeperiod))
startdate = startdate - relativedelta(days=+int(nsteps))

# get range of dates
dataframe = dataframe.sort_index().loc[str(startdate):str(enddate)]
Expand Down
12 changes: 8 additions & 4 deletions covate/covate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,24 @@ def main():
help="Select end date to take from metadata. Format: d/m/Y")
parser.add_argument("-v", "--validate", dest="validate", type=bool, required=False, default="True",
help="Run validation forecast. True or False")
parser.add_argument("-m", "--max-lags", dest="maxlags", required=False, default="14",
help="Maximum number of lags to investigate")
parser.add_argument("-n", "--n-steps", dest="nsteps", required=False, default="14",
help="Number of days to predict")
args = parser.parse_args()

# build the time series
countbydate, lineagecommon, region_list, enddate = buildseries(args.metadata, args.regions, args.adm, args.lineagetype, args.timeperiod, args.enddate, args.output, False)
countbydate, lineagecommon, region_list, enddate = buildseries(args.metadata, args.regions, args.adm, args.lineagetype, args.timeperiod, args.enddate, args.output, args.nsteps, False)

# build the model
buildmodel(countbydate, lineagecommon, region_list, enddate, args.output, False)
buildmodel(countbydate, lineagecommon, region_list, enddate, args.output, args.maxlags, args.nsteps, False)

# if validation forecast selected, run again
if args.validate:

countbydate, lineagecommon, region_list, enddate = buildseries(args.metadata, args.regions, args.adm, args.lineagetype, args.timeperiod, args.enddate, args.output, args.validate)
countbydate, lineagecommon, region_list, enddate = buildseries(args.metadata, args.regions, args.adm, args.lineagetype, args.timeperiod, args.enddate, args.output, args.nsteps, args.validate)

buildmodel(countbydate, lineagecommon, region_list, enddate, args.output, args.validate)
buildmodel(countbydate, lineagecommon, region_list, enddate, args.output, args.maxlags, args.nsteps, args.validate)

if __name__ == '__main__':
main()

0 comments on commit 42b2255

Please sign in to comment.