From b596eaf063fa3b963e66c55256813ddd8d808f09 Mon Sep 17 00:00:00 2001 From: Thea Aarrestad <thea.aarrestad@cern.ch> Date: Thu, 7 Dec 2023 15:45:57 +0100 Subject: [PATCH] Updating to latest model/data --- part2/part2_compression.ipynb | 217 ++++++++++++++++++++-------------- setup.sh | 3 +- 2 files changed, 130 insertions(+), 90 deletions(-) diff --git a/part2/part2_compression.ipynb b/part2/part2_compression.ipynb index 31388f2..8a1776c 100644 --- a/part2/part2_compression.ipynb +++ b/part2/part2_compression.ipynb @@ -70,8 +70,11 @@ "outputs": [], "source": [ "from tensorflow.keras.models import load_model\n", + "import os\n", "\n", - "model_path = '/eos/home-t/thaarres/cms_mlatl1t_tutorial/full_model.h5'\n", + "part1_output_dir = os.environ['MLATL1T_DIR']+'/part1/part1_outputs/'\n", + "\n", + "model_path = part1_output_dir + '/model.h5'\n", "baseline_model = load_model(model_path)\n", "\n", "baseline_model.summary()" @@ -84,6 +87,49 @@ "source": [ "So we have 3 hidden layers with [64,32,32] neurons. We don't see it here, but they are all followed by an \"elu\" activation. The output is one node activated by a sigmoid activation function.\n", "\n", + "# Load the data from Part 1\n", + "\n", + "Let's also load the data from part one already now so we know what the input shape is for defining our quantized model. Afterwards we'll also further process this input before training it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7627ae9", + "metadata": {}, + "outputs": [], + "source": [ + "import awkward as ak\n", + "import pickle\n", + "\n", + "X_train = ak.from_parquet(part1_output_dir + \"/X_train_scaled.parquet\").to_numpy() \n", + "X_test = ak.from_parquet(part1_output_dir + \"/X_test_scaled.parquet\").to_numpy() \n", + "\n", + "y_train = ak.from_parquet(part1_output_dir + \"/y_train_scaled.parquet\").to_numpy()\n", + "y_test = ak.from_parquet(part1_output_dir + \"/y_test_scaled.parquet\").to_numpy()\n", + "\n", + "# In this case the test and train data is already scaled, but this is how you would laod and apply it:\n", + "#Load the scaler and parameters and apply to the data\n", + "scale = False\n", + "if scale:\n", + " file_path = part1_output_dir+'/scaler.pkl'\n", + "\n", + " with open(file_path, 'rb') as file:\n", + " scaler = pickle.load(file)\n", + "\n", + " X_train = scaler.transform(X_train)\n", + " X_test = scaler.transform(X_test);\n", + "\n", + "\n", + "print(f\"Training on {X_train.shape[0]} events, represented by {X_train.shape[1]} input features\")\n", + "print(f\"Testing on {X_test.shape[0]} events, represented by {X_test.shape[1]} input features\")" + ] + }, + { + "cell_type": "markdown", + "id": "808a79e1", + "metadata": {}, + "source": [ "## Translating to a QKeras QAT model\n", "There are two ways to translate this into a QKeras model that can be trained quantization aware, lets first do it manually:\n", "\n", @@ -106,7 +152,7 @@ "from qkeras.qlayers import QDense, QActivation\n", "from qkeras.quantizers import quantized_bits, quantized_relu\n", "\n", - "input_size=26\n", + "input_size=X_train.shape[1]\n", "\n", "# Define the input layer\n", "inputs = Input(shape=(input_size,))\n", @@ -167,7 +213,7 @@ "- ```bits```: The bitwidth, allowing you to have $2^{bits}$ unique values of each weight parameter\n", "- ```integers```: How many are integer bits, in this case zero. All 8 bits are used to represent the fractional part of the weight parameter, with no bits dedicated to representing whole numbers. This forces the value to be between -1 and 1. 
For DNNs this can be useful because the focus is entirely on the precision of the fraction rather than the magnitude of the number. Question: Would this also work on the output node if your algorithm is a regression of the jet mass?\n", "- ```symmetric```: should the values be symmetric around 0? In this case it doesnt have to be.\n", - "- ```alpha```: with $2^W$ unique values available, we could let them go from [-2^W, 2^W-1] like above, but we can also let them go from $[-2^W*\\alpha, (2^W-1)*\\alpha]$. ```alpha``` is a scaling of the weights. Enabling this often leads to improved performance, but it doesnt talk so nicely to hls4ml, so we recommend leaving it at 1 (or get ready for having to debug)\n", + "- ```alpha```: with $2^W$ unique values available, we could let them go from $[-2^W, 2^W-1]$ like above, but we can also let them go from $[-2^W*\\alpha, (2^W-1)*\\alpha]$. ```alpha``` is a scaling of the weights. Enabling this often leads to improved performance, but it doesnt talk so nicely to hls4ml, so we recommend leaving it at 1 (or get ready for having to debug)\n", "\n", "Having added this, QKeras will automatically apply fake quantization for us during the forward pass, accounting for the quantization error and returning a network that is optimized for the precision you plan on using in hardware.\n", "\n", @@ -188,26 +234,6 @@ "autoQuant = False\n", "\n", "if autoQuant:\n", - " # Fine grained, per-layer control\n", - " # config = {\n", - " # \"example_model_topo_fc1\": {\n", - " # \"kernel_quantizer\": \"quantized_bits(8,0,1)\",\n", - " # \"bias_quantizer\": \"quantized_bits(8,0,1)\",\n", - " # }, \n", - " # \"example_model_topo_activation1\": \"quantized_relu(8)\", \n", - "\n", - " # \"example_model_topo_fc2\": {\n", - " # \"kernel_quantizer\": \"quantized_bits(8,0,1)\",\n", - " # \"bias_quantizer\": \"quantized_bits(8,0,1)\",\n", - " # }, \n", - " # \"example_model_topo_activation2\": \"quantized_relu(8)\", \n", - " # \"example_model_topo_fc3\": {\n", - " # \"kernel_quantizer\": \"quantized_bits(8,0,1)\",\n", - " # \"bias_quantizer\": \"quantized_bits(8,0,1)\",\n", - " # }, \n", - " # example_model_topo_activation3: \"quantized_relu(8)\", \n", - " # } \n", - " # Coarse grained, per-layertype quantization\n", " config = {\n", " \"QDense\": {\n", " \"kernel_quantizer\": \"quantized_bits(bits=8, integer=0, symmetric=0, alpha=1)\",\n", @@ -296,56 +322,21 @@ { "cell_type": "code", "execution_count": null, - "id": "ce0355af", + "id": "9092b6db", "metadata": {}, "outputs": [], "source": [ - "import awkward as ak\n", - "import pickle\n", - "\n", - "path = '/eos/home-t/thaarres/cms_mlatl1t_tutorial/'\n", - "\n", - "X_train = ak.from_parquet(path + \"/X_train.parquet\").to_numpy() \n", - "X_test = ak.from_parquet(path + \"/X_test.parquet\").to_numpy() \n", - "\n", - "y_train = ak.from_parquet(path + \"/y_train.parquet\").to_numpy()\n", - "y_test = ak.from_parquet(path + \"/y_test.parquet\").to_numpy()\n", - "\n", - "# In this case the test and train data is already scaled, but this is how you would laod and apply it:\n", - "#Load the scaler and parameters and apply to the data\n", - "scale = False\n", - "if scale:\n", - " file_path = path+'/scaler.pkl'\n", - "\n", - " with open(file_path, 'rb') as file:\n", - " scaler = pickle.load(file)\n", "\n", - " X_train = scaler.transform(X_train)\n", - " X_test = scaler.transform(X_test);\n", - "\n", - "\n", - "print(f\"Training on {X_train.shape[0]} events, represented by {X_train.shape[1]} input features\")\n", - "print(f\"Testing on 
{X_test.shape[0]} events, represented by {X_test.shape[1]} input features\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "94d20cfb", - "metadata": {}, - "outputs": [], - "source": [ "import matplotlib.pyplot as plt\n", "\n", - "bins = 1024\n", + "bins = 4096\n", "\n", "plt.figure(figsize=(10, 6))\n", - "\n", - "plt.hist(X_train[:1000, :], bins=bins, stacked=True, label=[f'Input {i+1}' for i in range(X_train.shape[1])])\n", - "\n", + "#Input distribution, stacked per feature. This is very slow to plot, so lets look at all the features flattened later on\n", + "plt.hist(X_train, bins=bins, stacked=True, label=[f'Input {i+1}' for i in range(X_train.shape[1])]) \n", + "# plt.hist(X_train.flatten(), bins=bins, color='orangered', label='Floating point')\n", "plt.xlabel('Feature Value (standardized)')\n", "plt.ylabel('Frequency')\n", - "plt.title('Stacked Histogram of all features')\n", "plt.legend(loc='upper right', ncol=2)\n", "plt.semilogy()\n", "plt" @@ -356,7 +347,11 @@ "id": "a275844b", "metadata": {}, "source": [ - "In this case, the values seem to be <50 so lets assume 6 integer bits ($2^6=64$). The number of fractional bits will define our precision, and will affect the network performance. Let's assume 10 is sufficient (the smallest increment we can represent is $2^{-10}=0.0009765625$). To make our network adapt to this input precision, we need to \"treat\" our training and testing set with a quantizer to go from FP32 $\\rightarrow <16,6>$:" + "In this case, the values seem to be mostly <50, with a few outliers so lets assume 6 integer bits ($2^6=64$) is sufficient (the rest will get clipped). The number of fractional bits will define our precision, and will affect the network performance. Let's assume 10 is sufficient (the smallest increment we can represent is $2^{-10}=0.0009765625$).\n", + "\n", + "We can evaluate these choices by comparing the accuracy of the network to that in the previous part. \n", + "\n", + "To make our network adapt to this input precision, we need to \"treat\" our training and testing set with a quantizer to go from FP32 $\\rightarrow <16,6>$:" ] }, { @@ -385,11 +380,11 @@ "source": [ "plt.figure(figsize=(10, 6))\n", "\n", - "plt.hist(qX_train[:1000, :], bins=bins, stacked=True, label=[f'Input {i+1}' for i in range(X_train.shape[1])])\n", - "\n", - "plt.xlabel('Quantized Feature Value (standardized)')\n", + "# plt.hist(qX_train, bins=bins, stacked=True, label=[f'Input {i+1}' for i in range(X_train.shape[1])])\n", + "plt.hist(X_train.flatten(), bins=bins, color='orangered', label='Floating point')\n", + "plt.hist(qX_train.flatten(), bins=bins, color='royalblue', label='Quantized')\n", + "plt.xlabel('Feature Value (standardized)')\n", "plt.ylabel('Frequency')\n", - "plt.title('Stacked Histogram of all features')\n", "plt.legend(loc='upper right', ncol=2)\n", "plt.semilogy()\n", "plt" @@ -400,9 +395,12 @@ "id": "3e0e7438", "metadata": {}, "source": [ - "The weight distribution looks similar, but we can not really say how much we loose in performance before training with different input precisions.\n", + "The weight distribution looks similar, but we can not really say how much we lose in performance before training with different input precisions.\n", + "\n", + "## Train the network quantization aware\n", + "Phew, okay, finally time to train. For this part there are 2 things to note: you need to add a pruning callback and also you might need to adjust the learning rate (like add a learning rate decay). 
Also, most likely you need to increase the number of epochs.\n", "\n", - "Phew, okay, finally time to train. For this part there are 2 things to note: you need to add a pruning callback and also you might need to adjust the learning rate (like add a learning rate decay). Let's train!" + "Let's train!" ] }, { @@ -425,10 +423,10 @@ "early_stopping = EarlyStopping(monitor='val_loss', patience=5)\n", "callbacks=[early_stopping, reduce_lr, model_checkpoint, pruning_callbacks.UpdatePruningStep()]\n", "\n", - "adam = Adam(learning_rate=0.0001)\n", + "adam = Adam(learning_rate=0.001)\n", "qmodel.compile(optimizer=adam, loss=['binary_crossentropy'], metrics=['accuracy'])\n", "\n", - "qmodel.fit(qX_train, y_train, batch_size=2048, epochs=50,validation_split=0.20, shuffle=True,callbacks=callbacks,verbose=1) \n", + "qmodel.fit(qX_train, y_train, batch_size=4096, epochs=60,validation_split=0.20, shuffle=True,callbacks=callbacks,verbose=1) \n", "qmodel = strip_pruning(qmodel)\n", "qmodel.save('qtopo_model.h5')" ] @@ -438,6 +436,8 @@ "id": "75409ec1", "metadata": {}, "source": [ + "## Comparing to he floating point model\n", + "\n", "Before checking and comparing the accuracy, lets look at the weights and see if they look quantized and pruned:" ] }, @@ -484,7 +484,11 @@ "id": "cdd44876", "metadata": {}, "source": [ - "This looks like expected! Now, lets compare the performance to that of the floating point model:" + "This looks quantized and pruned indeed! Now, lets compare the performance to that of the floating point model. \n", + "\n", + "We are not so interested in false positive rate (FPR) and more interested in the absolute L1 rate, so lets convert it. We will Zoom into the region $<100$ kHz for obvious reasons, which means we are working at a very low FPR. \n", + "\n", + "Ealuating the performane at such high thresholds will require a lot of stiatistics, which luckily we have:" ] }, { @@ -568,7 +572,7 @@ "\n", "<img src=\"https://gitlab.cern.ch/fastmachinelearning/cms_mlatl1t_tutorial/-/raw/master/part2/images/hls4ml_logo.png?ref_type=heads\" width=\"400\"/>\n", "\n", - "Time to translate this model into HLS (which we will integrate in the emulator) and use to generate the vhdl to be integrated in the trigger firmware.\n", + "Time to translate this model into HLS (which we will integrate in the emulator) and use to generate the vhdl to be integrated in the trigger firmware. We will use the Python library hls4ml for that ([here](https://github.com/fastmachinelearning/hls4ml-tutorial/tree/main) is the hls4ml tutorial).\n", "hls4ml seamlessly talks to QKeras, making our jobs way easier for us, but there is still some work for us to do to make sure we get good hardware model accuracy. Lets start!\n", "There are a few things I already know in advance and would like my model to include:\n", "- Be execuded fully parallel (=unrolled) to reach the lowest possible latency. We set the ReuseFactor=1 and Strategy=Latency\n", @@ -580,7 +584,7 @@ "\n", "<img src=\"https://gitlab.cern.ch/fastmachinelearning/cms_mlatl1t_tutorial/-/raw/master/part2/images/hls4ml_precisions.png?ref_type=heads\" width=\"400\"/>\n", "\n", - "Whereas the $weight$ and $bias$ is set to its optimal value from the QKeras model, the accumulator $accum$ and $result$ is set to some default value that might not be optimal for a given model and might need tuning. 
Let's do a first attemt and compare the ROC curves:" + "Whereas the $weight$ and $bias$ is set to its optimal value from the QKeras model, the accumulator $accum$ and $result$ is set to some default value that might not be optimal for a given model and might need tuning. Let's do a first attemt:" ] }, { @@ -629,8 +633,8 @@ " project_name='L1TMLDemo_v1', \n", " part='xcu250-figd2104-2L-e', #Target FPGA, ideally you would use VU9P and VU13P that we use in L1T but they are not installed at lxplus, this one is close enought for this\n", " clock_period=2.5, # Target frequency 1/2.5ns = 400 MHz\n", - " input_data_tb='qX_test.npy', # For co-simulation\n", - " output_data_tb='qy_test.npy',# For co-simulation\n", + "# input_data_tb='qX_test.npy', # For co-simulation\n", + "# output_data_tb='qy_test.npy',# For co-simulation\n", ")\n", "hls_model.compile()" ] @@ -661,7 +665,12 @@ "Here you can see that the precision is what we set it to be in QKeras as well as what we set manually in the config. One thing to note is the different definitions used in QKeras and in ap_fixed:\n", "- ```quantized_bits(8,0) -> ap_fixed<8,1>```\n", "- ```quantized_relu(8,0) -> ap_ufixed<8,0>```\n", - "Also you can see that the defualt value for result/accu is set to $16,6$. This can also be tuned to more optimal values." + "Also you can see that the defualt value for result/accu is set to $16,6$. This can also be tuned to more optimal values.\n", + "\n", + "## Validate the firmware model accuracy\n", + "\n", + "#et's also run predict on the C++ implementation of our model and make sure it's the ~same as for the QKeras model.\n", + "This is very slow for the C++ implementation of our model, but we need a lot of statistics to probe the low rate region. Keep reading while you wait :)!\n" ] }, { @@ -671,13 +680,11 @@ "metadata": {}, "outputs": [], "source": [ - "# Let's also run predict on the C++ implementation of our model and make sure it's the ~same as for the QKeras model:\n", - "# This is very slow for the C++ implementation of our model, so lets only do 1000 events for this\n", "y_hls = hls_model.predict(np.ascontiguousarray(qX_test))\n", "\n", - "print(f\"Truth labels:\\n {y_test[17:27]}\")\n", - "print(f\"Qkeras prediction:\\n {qy_pred[17:27]}\")\n", - "print(f\"HLS prediction:\\n {y_hls[17:27]}\")" + "print(f\"Truth labels:\\n {y_test[17:27]}\\n\")\n", + "print(f\"Qkeras prediction:\\n {qy_pred[17:27]}\\n\")\n", + "print(f\"HLS prediction:\\n {y_hls[17:27]}\\n\")" ] }, { @@ -688,7 +695,7 @@ "outputs": [], "source": [ "# Lets plot it!\n", - "hlsfpr, hlstpr, hlsthr = roc_curve(y_test, y_hls, pos_label=None, sample_weight=None, drop_intermediate=True)\n", + "hlsfpr, hlstpr, hlsthr = roc_curve(y_test, y_hls, pos_label=1, sample_weight=None, drop_intermediate=True)\n", "hlsfpr *= totalMinBiasRate()\n", "hlsroc_auc = roc_auc_score(y_test, y_hls)\n", "\n", @@ -713,7 +720,13 @@ "id": "00985383", "metadata": {}, "source": [ - "Oh! That was easier than expected. If you see the accuracies differing significantly, it's a good idea to look into accumulator and reult precisions. Also with tools like $Trace$ and $Profiling$ that you can learn from in the [official hls4ml tutorial](https://github.com/fastmachinelearning/hls4ml-tutorial/blob/main/part2_advanced_config.ipynb) can be helpful! In this case, it doesnt seem like it's necessary. Now let's build it! 
Lets run C-synthesis (C++ to register-transfer level), Vivado logic synthesis (gate level representation) and co-simulation (send test vectors, do an exhaustive functional test of the implemented logic)" + "Oh! That was easier than expected. If you see the accuracies differing significantly, it's a good idea to look into accumulator and reult precisions. Also with tools like $Trace$ and $Profiling$ that you can learn from in the [official hls4ml tutorial](https://github.com/fastmachinelearning/hls4ml-tutorial/blob/main/part2_advanced_config.ipynb) can be helpful! In this case, it doesnt seem like it's necessary. \n", + "\n", + "## Synthesise!\n", + "\n", + "Now let's build it! Lets run C-synthesis (C++ to register-transfer level) and Vivado logic synthesis (gate level representation). We will not do co-simulation (send test vectors, do an exhaustive functional test of the implemented logic), but this can be a good idea if you are using CNNs and the $io_stream$ io. \n", + "\n", + "Let's run!" ] }, { @@ -723,7 +736,6 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"Running synthesis!\")\n", "report = hls_model.build(csim=False, synth=True, vsynth=True, cosim=False)" ] }, @@ -732,7 +744,11 @@ "id": "3893c6cd", "metadata": {}, "source": [ - "Now, lets, look at the reports! The latency can be extracted from the C-synthesis report, and validated from the co-simulation report (where actual data is sent through the logic. The resource consumption can be extracted from the implementation report (Vivado logic synthesis) and is more accurate then what is quoted in the C-synthesis report:" + "Now, lets, look at the reports! The latency can be extracted from the C-synthesis report, and validated from the co-simulation report (where actual data is sent through the logic. \n", + "\n", + "The resource consumption can be extracted from the implementation report (Vivado logic synthesis) and is more accurate then what is quoted in the C-synthesis report. \n", + "\n", + "In this case we did not run co-simulation (this mainly because important when using CNNs and io_stream), but lets print the rest:" ] }, { @@ -744,7 +760,7 @@ "source": [ "print(\"\\nC synthesis report (latency estimate):\")\n", "print_dict(report[\"CSynthesisReport\"])\n", - "#print_dict(report[\"CosimReport\"]) Not working due to missing libc header sys/cdefs.h :(?\n", + "#print_dict(report[\"CosimReport\"]) # If also running co-sim\n", "print(\"\\nVivado synthesis report (resource estimates):\")\n", "print_dict(report[\"VivadoSynthReport\"])" ] @@ -754,8 +770,31 @@ "id": "72c1723a", "metadata": {}, "source": [ - "A latency of $2.5\\cdot16=40$ ns, that is not bad! Also, the network is using very little resources: 8k out of 1728k LUTs, 15 out of 12k DSPs. This is <1% of the total available resources. We have a set of HLS files that will be integrated into the CMSSW emulator (```L1TMLDemo_v1/firmware/```) and VHDL that will be integrated into the mGT firmware (```L1TMLDemo_v1/myproject_prj/solution1/impl/vhdl/```). That's next!" + "A latency of $2.5\\cdot15=37.5$ ns, that is not bad! \n", + "\n", + "Also, the network is using very little resources: 5k out of 1728k LUTs, 26 out of 12k DSPs. This is <1% of the total available resources. We have a set of HLS files that will be integrated into the CMSSW emulator (```L1TMLDemo_v1/firmware/```) and VHDL that will be integrated into the mGT firmware (```L1TMLDemo_v1/myproject_prj/solution1/impl/vhdl/```). 
That's next!\n",
    "\n",
    "If you did not finish synthesising before the start of the next exercise, you can copy an already synthesised project from here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80d4fe1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ! cp /eos/home-t/thaarres/cms_mlatl1t_tutorial/L1TMLDemo_v1.tar.gz .\n",
    "# ! tar -xzvf L1TMLDemo_v1.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d954d26",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
diff --git a/setup.sh b/setup.sh
index 3ae9b69..e7212b9 100644
--- a/setup.sh
+++ b/setup.sh
@@ -28,4 +28,5 @@ cd $SCRIPT_DIR
 
 # put the HLS tools on the PATH
 echo "ML@L1T Setup: prepending $SCRIPT_DIR/bin to PATH"
-export PATH=$SCRIPT_DIR/bin:$PATH
\ No newline at end of file
+export PATH=$SCRIPT_DIR/bin:$PATH
+export MLATL1T_DIR=$SCRIPT_DIR
\ No newline at end of file
-- 
GitLab
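
The notebook above now locates the Part 1 outputs through the `MLATL1T_DIR` variable exported by `setup.sh`, so it fails with a bare `KeyError` if the setup script has not been sourced. A minimal guard cell along these lines (a sketch, not part of the patch; the directory layout is the one the notebook's first data-loading cell assumes) makes that failure mode explicit:

```python
# Sketch only (not part of the patch): check that setup.sh has been sourced
# before the notebook tries to read the Part 1 outputs via MLATL1T_DIR.
import os

mlatl1t_dir = os.environ.get("MLATL1T_DIR")
if mlatl1t_dir is None:
    raise RuntimeError("MLATL1T_DIR is not set; run `source setup.sh` from the tutorial root first.")

# Mirror the path the notebook builds: $MLATL1T_DIR/part1/part1_outputs/
part1_output_dir = os.path.join(mlatl1t_dir, "part1", "part1_outputs")
if not os.path.isdir(part1_output_dir):
    raise RuntimeError(f"No Part 1 outputs found in {part1_output_dir}; run the Part 1 notebook first.")

print(f"Reading Part 1 outputs from {part1_output_dir}")
```

Using `os.environ.get` rather than a direct lookup keeps the error message actionable for people who open the notebook without having gone through `setup.sh`.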