elaborate_scores.py 22.8 KB
Newer Older
1
2
3
#!/usr/bin/env python
import click
import sqlite3
4
import os
5
import pandas as pd
6
import yaml
7
import seaborn as sns
Matteo Paltenghi's avatar
Matteo Paltenghi committed
8
9
import matplotlib  # noqa
import matplotlib.pyplot as plt  # noqa
10
11
12
import numpy as np
import miscellanea.experiment_cern_dataset as exp_cern_ds

13
14
from datetime import datetime

15
from pathlib import Path
16
import sklearn.metrics
17
18
from miscellanea.experiment_comparison import plot_roc  # noqa
from sklearn.metrics import precision_recall_curve  # noqa
19
import sklearn
Antonin Dvorak's avatar
Antonin Dvorak committed
20
from adcern.sqlite3_backend import modify_db
21
22
23
24
25
26
27
28


# Example boundaries of one evaluation week; score_benchmark uses these as
# the reference interval for exp_cern_ds.assign_week to bucket rows into
# week indices.
MY_START_WEEK = datetime(year=2020, month=8, day=9)
MY_END_WEEK = datetime(year=2020, month=8, day=16)

# CONSOLE APPLICATION

29
30
31
# see https://github.com/pallets/click/issues/1123
def normalize_names(name):
    """Turn underscores into hyphens so both option spellings work."""
    return "-".join(name.split("_"))
Matteo Paltenghi's avatar
Matteo Paltenghi committed
32

33
# Root command group: every subcommand below registers under this entry
# point.  token_normalize_func makes "--my_option" and "--my-option"
# interchangeable on the command line.
@click.group(context_settings={"token_normalize_func": normalize_names})
def cli():
    print("Welcome in the visualization CLI.")

Matteo Paltenghi's avatar
Matteo Paltenghi committed
37

38
@cli.command()
@click.option('--input_folder', default="",
              help='''path where to look for parquet files to merge.''')
@click.option('--output_folder', default="",
              help='''path where to save combined_scores.parquet.''')
def merge_db(input_folder, output_folder):
    """Merge Parquet DB in the folder.

    Concatenates every ``scores_*.parquet`` file found in input_folder
    into a single ``combined_scores.parquet`` written to output_folder.

    Raises:
        FileNotFoundError: if no matching parquet file is found.
    """
    # get db files following the scores_<...>.parquet naming convention
    db_files = [f for f in os.listdir(input_folder)
                if os.path.isfile(os.path.join(input_folder, f))
                and f.startswith("scores_") and f.endswith(".parquet")]
    print("db_files:")
    print(db_files)

    # fail early with a clear message instead of letting pd.concat raise
    # an opaque "No objects to concatenate" ValueError
    if not db_files:
        raise FileNotFoundError(
            "No scores_*.parquet files found in: %s" % input_folder)

    # get scores of every algo-window and stack them vertically
    df_all_algos_all_windows = []
    for parquet_filename in db_files:
        print(parquet_filename)
        df_single_algo = \
            pd.read_parquet(os.path.join(input_folder, parquet_filename))
        df_all_algos_all_windows.append(df_single_algo)

    df_all = pd.concat(df_all_algos_all_windows, ignore_index=True)

    df_all.to_parquet(os.path.join(output_folder, "combined_scores.parquet"),
                      index=False)
64
65
66


@cli.command()
@click.option('--hostgroup', default="",
              help='''full name of the hostgroup to extract.''')
@click.option('--input_folder', default="",
              help='''path where to find labels and combined_scores.parquet.
              ''')
@click.option('--output_folder', default="",
              help='''path where to save the png plot file.''')
@click.option('--fixed_scale', default="True",
              help='''if you want to bing your scores on a fixed scale.''')
def create_corr_score(hostgroup, input_folder, output_folder, fixed_scale):
    """Plot the correlations among scores."""
    print('Creating correlation scores!')
    # READ ALL SCORES (one row per algorithm/hostname/window)
    df_all = pd.read_parquet(input_folder + "/combined_scores.parquet")

    # READ ALL LABELS
    # get the labels (matrix: rows = timestamps, columns = hostnames)
    name_hostgroup = hostgroup.split("/")[-1]
    in_csv = input_folder + "/" + name_hostgroup + "_labels.csv"
    df_labels = pd.read_csv(in_csv, index_col=0)
    # melt labels into long format: one row per (timestamp, hostname)
    df_labels_melt = pd.melt(df_labels.reset_index(),
                             id_vars='index',
                             value_vars=list(df_labels.columns))
    df_labels_melt["index"] = pd.to_datetime(df_labels_melt["index"])
    df_labels_melt.set_index("index", inplace=True)
    df_labels_melt.rename(columns={'value': 'label', 'variable': 'hostname'},
                          inplace=True)
    # label columns carry short names; scores use fully qualified hosts
    df_labels_melt['hostname'] = \
        df_labels_melt['hostname'].apply(lambda y: y + ".cern.ch")
    df_labels_melt.reset_index(drop=False, inplace=True)

    all_algos = list(df_all["algorithm"].unique())
    print("ALGOS: ", all_algos)

    # square matrix of subplots: one cell per (algo_row, algo_col) pair
    HEIGHT = 3
    ROWS = len(all_algos)

    # NOTE(review): with a single algorithm plt.subplots returns one Axes,
    # not a 2-D array, so axes[i][j] below assumes >= 2 algorithms - confirm
    fig, axes = plt.subplots(ncols=ROWS, nrows=ROWS)
    fig.set_size_inches(ROWS * HEIGHT, ROWS * HEIGHT)
    fig.tight_layout(pad=3)

    # FOR EVERY ALGO -------------------------------------------------------
    for i, algo_name_row in enumerate(all_algos):
        for j, algo_name_col in enumerate(all_algos):
            # get current axes
            c_ax = axes[i][j]
            print("i=%i, j=%i, row=%s, col=%s" %
                  (i, j, algo_name_row, algo_name_col))

            # GET SCORES OF TWO ALGOs ----------------------------------------
            df_scores_algo_row = df_all[df_all["algorithm"] == algo_name_row]
            df_scores_algo_col = df_all[df_all["algorithm"] == algo_name_col]

            # COMBINE SCORES TWO ALGOS
            # inner-join on the same window/host; the overlapping "score"
            # column gets pandas' default _x/_y merge suffixes
            df_algos = \
                df_scores_algo_row.merge(
                    df_scores_algo_col,
                    on=['end_window', 'hostname', 'hostgroup'])
            df_algos.rename(columns={"score_x": "score_row",
                                     "score_y": "score_col"},
                            inplace=True)
            df_algos = df_algos[["end_window", "hostname",
                                "score_row", "score_col"]]

            # FIX TIME OF ALGOS - TIMECHANGE
            # end_window is epoch seconds; +3600 presumably shifts to the
            # local (CET) wall clock - TODO confirm timezone intent
            df_algos["end_window"] = df_algos["end_window"].astype('int')
            df_algos["timestamp"] = \
                pd.to_datetime(df_algos["end_window"] + 60 * 60, unit='s')
            df_algos["index"] = pd.to_datetime(df_algos["timestamp"])
            df_algos.index = df_algos["index"]
            df_algos = exp_cern_ds.change_time_for_scores(df_algos)

            # COMBINE - LABELS VS SCORES -------------------------------------
            df_to_evaluate = \
                df_algos.merge(df_labels_melt, on=['index', 'hostname'])
            df_to_evaluate.reset_index(drop=False, inplace=True)

            # REPLACE - DROP MIX AND EMPTY  WINDOWS
            # label value 2 marks mixed/empty windows: mask then drop
            df_to_evaluate = df_to_evaluate.replace(2, np.NaN)
            df_to_evaluate = df_to_evaluate.dropna(axis=0)

            # TRUTH ----------------------------------------------------------
            # plot anomalies on top of normal if on the upper part of the
            # matrix: upper triangle (i > j) shows label==1, lower label==0
            df_to_evaluate = \
                df_to_evaluate[df_to_evaluate["label"] == int(i > j)]
            if (i > j):
                c_ax.set_title("Anomalies")
                my_color = "darkorange"
            else:
                c_ax.set_title("Normal")
                my_color = "dodgerblue"

            scores_row = list(df_to_evaluate['score_row'])
            scores_col = list(df_to_evaluate['score_col'])

            sns.scatterplot(y=scores_row, x=scores_col, color=my_color,
                            ax=c_ax, marker="x")
            # fixed_scale is a string flag coming from the CLI, not a bool
            if fixed_scale == "True":
                lim_min = -2.5
                lim_max = 10
                c_ax.set_ylim((lim_min, lim_max))
                c_ax.set_xlim((lim_min, lim_max))
            c_ax.set_xlabel(algo_name_col)
            c_ax.set_ylabel(algo_name_row)
            # c_ax.legend()
            c_ax.grid()
    plt.savefig(output_folder + "/corr_output.png")
174
175


176
177
178
179
180
@cli.command()
@click.option('--hostgroup', default="",
              help='''full name of the hostgroup to inspect.''')
@click.option('--input_folder', default="",
              help='''path where to look for labels.''')
def label_presence(hostgroup, input_folder):
    """Check if labels for this hostgroup are already present."""
    # the labels csv is named after the last path component of the hostgroup
    short_name = hostgroup.split("/")[-1]
    labels_csv = "%s/%s_labels.csv" % (input_folder, short_name)
    labels = pd.read_csv(labels_csv, index_col=0)
    nr_labels = len(labels)
    print("We have %i labels for the hostgroup: %s." % (nr_labels, hostgroup))


190
191
192
193
194
195
196
@cli.command()
@click.option('--grafana_token', default="",
              help='''path to the grafana token.''')
@click.option('--hostgroup', default="",
              help='''full name of the hostgroup to extract.''')
@click.option('--output_folder', default="",
              help='''path where to save the csv file.''')
def extract_annotation(grafana_token, hostgroup, output_folder):
    """Extract the annotation from grafana and save a csv file."""
    # query the Grafana annotation API for this single hostgroup
    raw_annotations = exp_cern_ds.query_for_annotations(
        hostgroups=[hostgroup],
        file_path_token=grafana_token)
    # convert the API response into a tabular DataFrame
    df_annotations = exp_cern_ds.convert_to_dataframe(raw_annotations)
    # save to file, named after the hostgroup leaf
    short_name = hostgroup.split("/")[-1]
    Path(output_folder).mkdir(parents=True, exist_ok=True)
    out_csv = output_folder + "/" + short_name + "_annotations.csv"
    print("annotations csv out: ", out_csv)
    df_annotations.to_csv(out_csv, index=False)

Matteo Paltenghi's avatar
Matteo Paltenghi committed
214

215
216
217
218
219
220
221
222
223
224
225
@cli.command()
@click.option('--hostgroup', default="",
              help='''full name of the hostgroup to extract.''')
@click.option('--input_folder', default="",
              help='''path where to find annotation csv file.''')
@click.option('--analysis_file', default="",
              help='''json with info about the window size of analysis.''')
@click.option('--config_file', default="",
              help='''json with info about the granularity.''')
@click.option('--output_folder', default="",
              help='''path where to save the csv file.''')
def create_labels(hostgroup, input_folder, analysis_file,
                  config_file, output_folder):
    """Convert interval annotations into window labels.

    The window length is taken from the analysis_file.
    """
    # read the interval annotations saved by extract-annotation
    short_name = hostgroup.split("/")[-1]
    annotations_csv = input_folder + "/" + short_name + "_annotations.csv"
    df_annotations = pd.read_csv(annotations_csv)
    # window size (in slide steps) comes from the analysis yaml
    with open(analysis_file) as fh:
        analysis_cfg = yaml.safe_load(fh)
    # granularity (minutes per step) comes from the etl yaml
    with open(config_file) as fh:
        etl_cfg = yaml.safe_load(fh)
    steps_in_a_window = int(analysis_cfg["slide_steps"])
    granularity = int(etl_cfg["aggregate_every_n_minutes"])
    print("steps_in_a_window: ", steps_in_a_window)
    print("granularity: ", granularity)
    # one non-overlapping window spans steps * minutes-per-step minutes
    df_labels = \
        exp_cern_ds.count_per_interval(
            df_raw_annotations=df_annotations,
            nr_min_in_a_window=steps_in_a_window * granularity)
    out_csv = output_folder + "/" + short_name + "_labels.csv"
    print("window labels csv out: ", out_csv)
    df_labels.to_csv(out_csv)


@cli.command()
@click.option('--hostgroup', default="",
              help='''full name of the hostgroup to extract.''')
@click.option('--input_folder', default="",
              help='''path where to find labels csv file.''')
@click.option('--output_folder', default="",
              help='''path where to save the plotting image.''')
def visualize_labels(hostgroup, input_folder, output_folder):
    """Visualize label."""
    # load the window labels produced by create-labels
    short_name = hostgroup.split("/")[-1]
    labels_csv = input_folder + "/" + short_name + "_labels.csv"
    df_labels = pd.read_csv(labels_csv, index_col=0)
    # render the label matrix as a heatmap via the shared helper
    exp_cern_ds.visualize_heatmap_annotations(df_labels, figsize=(12, 20))
    plt.savefig(output_folder + "/" + short_name + "_labels.png",
                bbox_inches='tight')


@cli.command()
@click.option('--folder_scores', default="",
              help='''the  path to the folder with the sqlite3 database named
                      scores.db
                    ''')
@click.option('--hostgroup', default="",
              help='''full name of the hostgroup to extract.''')
@click.option('--labels_folder', default="",
              help='''path where to find labels csv file.''')
@click.option('--algo_name', default="",
              help='''name of the algorithm to score.''')
@click.option('--family', default="",
              help='''either: Traditional, Deep, Ensemble.''')
@click.option('--start', default="",
              help='''start of evaluation period.
                      Note the unusual presence of underscore "_"
                      e.g. "2020-02-13_16:00:00"''')
@click.option('--end', default="",
              help='''end of evaluation period.
                      Note the unusual presence of underscore "_"
                      e.g. "2020-03-01_00:00:00"''')
def score_benchmark(folder_scores, hostgroup,
                    labels_folder, algo_name, family,
                    start, end):
    """Score the algorithm anomaly scores against the labels.

    Aligns the combined scores parquet with the window labels csv on
    (timestamp, hostname), computes one ROC-AUC per week in the
    [start, end] interval and persists each result in the ``auc`` table
    of labels_folder/week_metrics.db.

    Raises:
        ValueError: if start or end is not given.
    """
    # get labels (matrix: rows = timestamps, columns = hostnames)
    name_hostgroup = hostgroup.split("/")[-1]
    in_csv = labels_folder + "/" + name_hostgroup + "_labels.csv"
    df_labels = pd.read_csv(in_csv, index_col=0)
    # melt label matrix into a long table: (index, hostname, label)
    df_labels_melt = pd.melt(df_labels.reset_index(),
                             id_vars='index',
                             value_vars=list(df_labels.columns))
    df_labels_melt.set_index("index", inplace=True)
    df_labels_melt.rename(columns={'value': 'label', 'variable': 'hostname'},
                          inplace=True)
    # label columns carry short names; scores use fully qualified hosts
    df_labels_melt['hostname'] = \
        df_labels_melt['hostname'].apply(lambda y: y + ".cern.ch")
    df_labels_melt.reset_index(drop=False, inplace=True)
    df_labels_melt["index"] = pd.to_datetime(df_labels_melt["index"])
    # get scores
    # df_all contains all_algos and all_windows
    df_all = pd.read_parquet(folder_scores + "/combined_scores.parquet")
    # copy() so the column assignments below act on an independent frame
    # rather than on a view of df_all (avoids SettingWithCopy issues)
    df_algo = df_all[df_all["algorithm"] == algo_name].copy()
    df_algo["end_window"] = df_algo["end_window"].astype('int')
    df_algo["timestamp"] = \
        pd.to_datetime(df_algo["end_window"] + 60 * 60, unit='s')
    df_algo["index"] = pd.to_datetime(df_algo["timestamp"])
    df_algo.index = df_algo["index"]
    # TODO timezones handling - problems after April
    df_algo = exp_cern_ds.change_time_for_scores(df_algo)
    if start == "" or end == "":
        # original message concatenated to "...of yourbenchmark" and had a
        # typo; ValueError is still caught by any `except Exception`
        raise ValueError("You must explicitly declare start and end of "
                         "your benchmark")
    # the CLI passes "_" instead of the space between date and time
    start = start.replace("_", " ")
    end = end.replace("_", " ")
    df_algo = df_algo[start: end]
    df_algo["index"] = df_algo.index
    # compare - merge scores and labels on (timestamp, hostname)
    df_all = df_algo.merge(df_labels_melt, on=['index', 'hostname'])
    df_all.reset_index(drop=False, inplace=True)
    df_all = df_all[["index", "hostname", "score", "label", "timestamp"]]
    # create metrics
    df_to_evaluate = df_all.copy()
    # REPLACE - DROP MIX AND EMPTY  WINDOWS --------------------------------
    # label value 2 marks mixed/empty windows: mask then drop them
    df_to_evaluate = df_to_evaluate.replace(2, np.NaN)
    df_to_evaluate = df_to_evaluate.dropna(axis=0)
    # ADD WEEKs ------------------------------------------------------------
    # bucket every row into a week index relative to the reference week
    df_to_evaluate["week"] = \
        df_to_evaluate["index"].apply(
            lambda x: exp_cern_ds.assign_week(x,
                                              dt_start_week=MY_START_WEEK,
                                              dt_end_week=MY_END_WEEK))
    # MULTIPLE ROC ---------------------------------------------------------
    weeks_available = list(df_to_evaluate["week"].unique())
    aucs_weeks = []
    # connect to the db (long timeout: concurrent runs may hold the lock)
    conn_score = sqlite3.connect(labels_folder + '/week_metrics.db',
                                 timeout=120)
    # ensure the table is there
    modify_db(
        conn=conn_score,
        query='''CREATE TABLE IF NOT EXISTS auc
                (hostgroup text, algorithm text, family text,
                auc_score real, week_index int, end_week int,
                PRIMARY KEY
                (hostgroup, algorithm, end_week))''',
        upperbound=10)

    # FOR EVERY WEEK -------------------------------------------------------
    for w in sorted(weeks_available):
        print("WEEK: ", w, end=" - ")
        df_this_week = df_to_evaluate[df_to_evaluate["week"] == w]
        dt_end_week = df_this_week["timestamp"].max()
        # epoch seconds of the last timestamp seen in this week
        end_week = int((dt_end_week - datetime(1970, 1, 1)).total_seconds())
        weekly_truth = list(df_this_week['label'])
        weekly_my_guess = list(df_this_week['score'])
        fpr, tpr, trs_roc = \
            sklearn.metrics.roc_curve(weekly_truth, weekly_my_guess)
        roc_auc = sklearn.metrics.auc(fpr, tpr)
        print("AUC: ", roc_auc)
        aucs_weeks.append(roc_auc)
        # INSERT OR IGNORE: reruns keep the first stored row for a week
        modify_db(
            conn=conn_score,
            query='''INSERT OR IGNORE INTO auc
                    VALUES (?, ?, ?, ?, ?, ?)''',
            upperbound=10,
            params=(hostgroup, algo_name, family, roc_auc, int(w),
                    end_week))
    conn_score.close()
452

453
454
455
456
457
458
459
460

@cli.command()
@click.option('--hostgroup', default="",
              help='''full name of the hostgroup to extract.''')
@click.option('--input_folder', default="",
              help='''path where to find week_metrics.db file.''')
@click.option('--output_folder', default="",
              help='''path where to save the plotting image.''')
def visualize_auc(hostgroup, input_folder, output_folder):
    """Visualize the AUC results from the selected algo.

    Writes two png files to output_folder: a barplot of the mean weekly
    AUC per algorithm and a lineplot of the weekly AUC evolution.
    """
    # read the database of scores in a pandas DF.
    # parameterized query instead of str.format: safe against quotes in
    # the hostgroup value and SQL injection.
    conn = sqlite3.connect(input_folder + "/week_metrics.db", timeout=120)
    try:
        df_week_auc = pd.read_sql_query(
            "SELECT * FROM auc WHERE hostgroup=?",
            conn, params=(hostgroup,))
    finally:
        # close the connection even if the query fails
        conn.close()

    # BARPLOT AUC - algorithms ordered by their mean weekly AUC
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(10, 5)
    result = \
        df_week_auc.groupby(["algorithm"])['auc_score'].aggregate(np.mean)\
        .reset_index().sort_values('auc_score')
    sns.barplot(x="algorithm", y="auc_score",
                data=df_week_auc.rename(columns={"family": "Family"}),
                hue="Family",
                order=result['algorithm'], dodge=False,
                capsize=.15, errwidth=2.1, ax=ax)
    ax.set_title("Performance AUC-ROC averaged over the weeks", size=24)
    ax.set_ylabel("Average AUC-ROC")
    ax.set_ylim(0.5, 1)
    ax.set_xlabel("Anomaly Detection Methods")
    plt.xticks(rotation=90)

    ax.yaxis.label.set_size(18)
    ax.xaxis.label.set_size(18)
    plt.grid()
    plt.savefig(output_folder + "/auc_comparison_barplot.png",
                bbox_inches='tight')

    # LINEPLOT - EVOLUTION
    # pivot requires unique (week_index, algorithm) pairs: drop duplicates
    df_week_auc.drop_duplicates(inplace=True)
    df_week_auc_wide = df_week_auc.pivot(index="week_index",
                                         columns="algorithm",
                                         values="auc_score")
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(10, 5)
    # prepare styles: one dash pattern per algorithm ("" means solid)
    nr_styles_required = len(df_week_auc["algorithm"].unique())
    dash_styles = ["",
                   (4, 1.5),
                   (1, 1),
                   (3, 1, 1.5, 1),
                   (5, 1, 1, 1),
                   (5, 1, 2, 1, 2, 1),
                   (2, 2, 3, 1.5),
                   (1, 2.5, 3, 1.2),
                   (1, 2, 3, 4),
                   (6, 1),
                   (7, 1),
                   (8, 1),
                   (9, 1),
                   (10, 1)]
    # pad with solid lines if there are more algorithms than patterns
    while len(dash_styles) < nr_styles_required:
        dash_styles.append("")
    # prepare colors - depending on the family; read the algo/family
    # pairing from any single week (every week lists all algorithms)
    rnd_week_index = list(df_week_auc["week_index"])[0]
    algos = list(df_week_auc[
        df_week_auc["week_index"] == rnd_week_index]["algorithm"])
    families = list(df_week_auc[
        df_week_auc["week_index"] == rnd_week_index]["family"])

    def get_color(x):
        """Map an algorithm family to its fixed plot color."""
        if x == "Traditional":
            return 'darkorange'
        if x == "Deep":
            return 'dodgerblue'
        if x == "Ensemble":
            return 'red'

    colors = {k: get_color(v) for k, v in zip(algos, families)}
    sns.lineplot(data=df_week_auc_wide, ax=ax,
                 dashes=dash_styles, palette=colors)
    ax.set_title("Evolution of AUC-ROC over weeks", size=24)
    ax.set_ylabel("Weekly AUC-ROC Score")
    ax.set_xlabel("Weeks evolution")
    ax.set_ylim(0.5, 1)
    ax.legend(loc='lower left')
    ax.yaxis.label.set_size(18)
    ax.xaxis.label.set_size(18)
    plt.grid()
    plt.savefig(output_folder + "/auc_comparison_evolution.png",
                bbox_inches='tight')
545
546


547
548
549
@cli.command()
@click.option('--input_folder', default="",
              help='''path where to remove week_metrics.db file.''')
def remove_old_database(input_folder):
    """Remove old database for weekly AUC results.

    Best-effort cleanup: a missing file (or any other OS-level problem)
    is only reported, not raised.
    """
    try:
        os.remove(input_folder + "/week_metrics.db")
    except OSError as e:
        # narrow catch: os.remove only raises OSError subclasses for
        # path problems; a broad `except Exception` would hide real bugs
        print(e)

Domenico Giordano's avatar
Domenico Giordano committed
557
558
def main():
    """Entry point: dispatch to the click command group."""
    cli()


if __name__ == '__main__':
    main()