%% Cell type:code id:d60b2756 tags:
``` python
#Install some dependencies
%pip install pyarrow hls4ml pyparser
```
%% Cell type:markdown id:bd9c5bc4 tags:
# Quantization aware training with QKeras
Quantization is a powerful way to reduce model memory and resource consumption. In this tutorial, we will use the library QKeras to perform quantization-aware training (QAT).
In contrast to Keras, where models are trained using floating-point precision, QKeras quantizes each of the model weights and activation functions during training, allowing the network to adapt to the numerical precision that will eventually be used in hardware.
During the forward pass of the network, each floating-point weight is put into one of $2^{bitwidth}$ buckets; which one it goes into is defined through rounding and clipping schemes.
Below you can see an example of a tensor $x_{f}$ with a (symmetric) dynamic range $[-amax, amax]$ mapped through quantization to an 8-bit integer, i.e. $2^8=256$ discrete values in the interval $[-128, 127]$ (for comparison, 32-bit floating point can represent ~4 billion numbers in the interval $[-3.4e38, 3.4e38]$).
<img src="images/8-bit-signed-integer-quantization.png" width="800"/>
<img src="https://gitlab.cern.ch/fastmachinelearning/cms_mlatl1t_tutorial/-/raw/master/part2/images/8-bit-signed-integer-quantization.png?ref_type=heads" width="800"/>
Quantization of floating point numbers can be achieved using the quantization operation
$$x_{q} = Clip(Round(x_{f}/scale))$$
where $x_{q}$ is the quantized value and $x_{f}$ is the floating-point value. $Round$ is a function that applies some rounding scheme to each number and $Clip$ is a function that clips outliers that fall outside the $[-128, 127]$ interval. The $scale$ parameter is obtained by dividing the floating-point dynamic range into 256 equal parts.
On FPGAs we do not use int8 quantization but fixed-point quantization; the idea, however, is similar. Fixed-point representation is a way to express fractions with integers and offers more control over precision and range. We can split the $W$ bits making up a number (in our case $W=8$) into an integer part and a fractional part. We usually reserve one bit for the sign. The radix point splits the remaining $W-1$ bits into $I$ most significant bits representing the integer value and $F$ least significant bits representing the fraction. We write this as $<W,I>$, where $F=W-1-I$. Here is an example for an unsigned $<8,5>$ number (no sign bit, so $F=W-I=3$):
<img src="images/fixedpoint.png" width="400"/>
<img src="https://gitlab.cern.ch/fastmachinelearning/cms_mlatl1t_tutorial/-/raw/master/part2/images/fixedpoint.png?ref_type=heads" width="400"/>
This fixed point number corresponds to $2^4\cdot0+2^3\cdot0+2^2\cdot0+2^1\cdot1+2^0\cdot0+2^{-1}\cdot1+2^{-2}\cdot1+2^{-3}\cdot0=2.75$.
The choice of $I$ and $F$ has to be derived as a trade-off between representation range and precision, where $I$ controls the range and $F$ the precision.
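As a quick sanity check, you can reproduce the conversion above in a couple of lines of plain Python (the bit pattern below is the one from the 2.75 example):
``` python
# Interpret the 8-bit pattern from the example above as unsigned fixed point
# with 5 integer bits and 3 fractional bits: value = int(bits, 2) / 2^F
bits = "00010110"                   # MSB first, i.e. 00010.110
frac_bits = 3
print(int(bits, 2) / 2**frac_bits)  # 22 / 8 = 2.75
```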
In the following we will use a bitwidth of 8 and 0 integer bits. With one bit reserved for the sign, this leaves $F=7$ fractional bits, so the smallest increment you can represent (the precision) and the representable range are:
$$\mathrm{Precision} = \frac{1}{2^{F}} = \frac{1}{2^{7}} = 0.0078125$$
$$\mathrm{Range} = [-2^{I},\ 2^{I}-2^{-F}] = [-1,\ 1-2^{-7}]$$
With zero integer bits the largest number you can represent is just below (but not including) 1. For weights in deep neural networks, being constrained to magnitudes below 1 is often a reasonable assumption.
What QKeras (and other QAT libraries) does is include the quantization error during training, in the following way (a minimal sketch follows the list):
- "Fake quantize" the floating-point weights and activations during the forward pass: quantize the weights and use them for the layer operations
- Immediately un-quantize the parameters so the rest of the computations take place in floating-point
- During the backward pass, the gradient of the weights is used to update the floating point weight
- The quantization operation gradient (zero or undefined) is handled by passing the gradient through as is ("straight through estimator")
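To make the forward/backward asymmetry concrete, here is a minimal sketch of fake quantization with a straight-through estimator written in plain TensorFlow. This is only an illustration, not the actual QKeras implementation:
``` python
import tensorflow as tf

@tf.custom_gradient
def fake_quantize_8bit(x):
    # Forward pass: snap x onto the 2^8 levels in [-1, 1) (scale, round, clip, rescale)
    scale = 2.0**7
    x_q = tf.clip_by_value(tf.round(x * scale), -scale, scale - 1.0) / scale

    def grad(dy):
        # Backward pass: straight-through estimator, the gradient passes through unchanged
        return dy

    return x_q, grad

# The loss sees quantized weights, but the gradient updates the underlying float32 values
w = tf.Variable([0.1234, -0.9876, 0.5555])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(fake_quantize_8bit(w) ** 2)
print(tape.gradient(loss, w))
```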
## Inspect the original model
In the following we will use the QKeras library to add quantizers to our model. First, let's load the baseline model and remind ourselves what the architecture looks like:
%% Cell type:code id:0e6c684c tags:
``` python
from tensorflow.keras.models import load_model
model_path = '/eos/home-t/thaarres/cms_mlatl1t_tutorial/full_model.h5'
baseline_model = load_model(model_path)
baseline_model.summary()
```
%% Cell type:markdown id:eb84f91a tags:
So we have 3 hidden layers with [64,32,32] neurons. We don't see it here, but they are all followed by an "elu" activation. The output is one node activated by a sigmoid activation function.
## Translating to a QKeras QAT model
There are two ways to translate this into a QKeras model that can be trained quantization aware; let's first do it manually:
### Manual QKeras model definition:
%% Cell type:code id:d5073f61 tags:
``` python
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l1
from tensorflow.keras.layers import Activation
from qkeras.qlayers import QDense, QActivation
from qkeras.quantizers import quantized_bits, quantized_relu
input_size=26
# Define the input layer
inputs = Input(shape=(input_size,))
# Define the three hidden layers and output layer
hidden1 = QDense(
64,
name='qd1',
kernel_quantizer=quantized_bits(bits=8, integer=0, symmetric=0, alpha=1),
bias_quantizer=quantized_bits(bits=8, integer=0, symmetric=0, alpha=1),
kernel_initializer='lecun_uniform',
kernel_regularizer=l1(0.0001),
) (inputs)
hidden1 = QActivation(activation=quantized_relu(8), name='qrelu1')(hidden1)
hidden2 = QDense(
32,
name='qd2',
kernel_quantizer=quantized_bits(bits=8, integer=0, symmetric=0, alpha=1),
bias_quantizer=quantized_bits(bits=8, integer=0, symmetric=0, alpha=1),
kernel_initializer='lecun_uniform',
kernel_regularizer=l1(0.0001),
) (hidden1)
hidden2 = QActivation(activation=quantized_relu(8), name='qrelu2')(hidden2)
hidden3 = QDense(
32,
name='qd3',
kernel_quantizer=quantized_bits(bits=8, integer=0, symmetric=0, alpha=1),
bias_quantizer=quantized_bits(bits=8, integer=0, symmetric=0, alpha=1),
kernel_initializer='lecun_uniform',
kernel_regularizer=l1(0.0001),
) (hidden2)
hidden3 = QActivation(activation=quantized_relu(8), name='qrelu3')(hidden3)
# Define the output layer with a single node; let's be careful with quantizing this one and be a bit more generous
# Some prefer to leave this as a Keras Dense layer, but then it requires more manual tuning in the hls4ml part
logits = QDense(1,
name='logits',
kernel_quantizer=quantized_bits(bits=13, integer=0, symmetric=0, alpha=1),
bias_quantizer=quantized_bits(bits=13, integer=0, symmetric=0, alpha=1),
kernel_initializer='lecun_uniform',
kernel_regularizer=l1(0.0001),
) (hidden3)
output = Activation(activation='sigmoid', name='output')(logits)
# Create the model
qmodel = Model(inputs=inputs, outputs=output)
# Model summary
qmodel.summary()
```
%% Cell type:markdown id:2c429c55 tags:
Wait! What is going on here?
The magic happens in ```quantized_bits``` (see implementation [here](https://github.com/google/qkeras/blob/master/qkeras/quantizers.py#L1245)), where the parameters are the following:
- ```bits```: the bitwidth, allowing you to have $2^{bits}$ unique values for each weight parameter
- ```integer```: how many of those bits are integer bits, in this case zero. All remaining bits are used to represent the fractional part of the weight parameter, with no bits dedicated to representing whole numbers. This forces the value to lie between -1 and 1. For DNNs this can be useful because the focus is entirely on the precision of the fraction rather than the magnitude of the number. Question: Would this also work on the output node if your algorithm is a regression of the jet mass?
- ```symmetric```: should the values be symmetric around 0? In this case it doesn't have to be.
- ```alpha```: with $2^{bits}$ unique values available, we could let the quantized levels span $[-2^{bits-1}, 2^{bits-1}-1]$ (times the fixed step) as above, but we can also scale them to $[-2^{bits-1}\cdot\alpha, (2^{bits-1}-1)\cdot\alpha]$. ```alpha``` is a scaling of the weights. Enabling this often leads to improved performance, but it doesn't play so nicely with hls4ml, so we recommend leaving it at 1 (or get ready for some debugging).
Having added this, QKeras will automatically apply fake quantization for us during the forward pass, accounting for the quantization error and returning a network that is optimized for the precision you plan on using in hardware.
Another thing to notice is that we leave the sigmoid and the final output logit unquantized. This is because this is where we want the values to be very accurate, and quantizing them is not going to save us a lot of resources.
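To get a feel for what ```quantized_bits``` does to individual numbers, you can apply the quantizer directly to a few values (a standalone check; nothing later in the tutorial depends on it):
``` python
import numpy as np
from qkeras.quantizers import quantized_bits

quantizer = quantized_bits(bits=8, integer=0, symmetric=0, alpha=1)
x = np.array([0.017, -0.26, 0.5, 1.3], dtype=np.float32)
# Values snap to multiples of 2^-7 and clip just below +/-1
print(quantizer(x).numpy())
```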
### Automatic model quantization through config
You can also set the quantization for the full model using a model configuration. Sometimes this can be faster if you're using the same quantizer for all layers of the same type. You don't have to use this for this tutorial, since we already have a model, but we will leave it here as an example:
%% Cell type:code id:1b138d1f tags:
``` python
autoQuant = False
if autoQuant:
    # Fine grained, per-layer control
    # config = {
    #     "example_model_topo_fc1": {
    #         "kernel_quantizer": "quantized_bits(8,0,1)",
    #         "bias_quantizer": "quantized_bits(8,0,1)",
    #     },
    #     "example_model_topo_activation1": "quantized_relu(8)",
    #     "example_model_topo_fc2": {
    #         "kernel_quantizer": "quantized_bits(8,0,1)",
    #         "bias_quantizer": "quantized_bits(8,0,1)",
    #     },
    #     "example_model_topo_activation2": "quantized_relu(8)",
    #     "example_model_topo_fc3": {
    #         "kernel_quantizer": "quantized_bits(8,0,1)",
    #         "bias_quantizer": "quantized_bits(8,0,1)",
    #     },
    #     "example_model_topo_activation3": "quantized_relu(8)",
    # }
    # Coarse grained, per-layertype quantization
    config = {
        "QDense": {
            "kernel_quantizer": "quantized_bits(bits=8, integer=0, symmetric=0, alpha=1)",
            "bias_quantizer": "quantized_bits(bits=8, integer=0, symmetric=0, alpha=1)",
        },
        "QActivation": {"relu": "quantized_relu(8)"},
    }
    from qkeras.utils import model_quantize

    # Quantize the floating-point baseline model, transferring its weights
    qmodel = model_quantize(baseline_model, config, 4, transfer_weights=True)
    for layer in qmodel.layers:
        if hasattr(layer, "kernel_quantizer"):
            print(layer.name, "kernel:", str(layer.kernel_quantizer_internal), "bias:", str(layer.bias_quantizer_internal))
        elif hasattr(layer, "quantizer"):
            print(layer.name, "quantizer:", str(layer.quantizer))
    print()
    qmodel.summary()
```
%% Cell type:markdown id:6947eda4 tags:
But be careful that activation functions like softmax/sigmoid, and any logit layers you want to keep at full precision, don't get quantized!
%% Cell type:markdown id:e59eea22 tags:
## But how many bits?
So now we know how to quantize our models, but how do we know which precision to choose?
Finding the best number of bits and integer bits to use is non-trivial, and there are two ways we recommend:
- The easiest strategy is to scan over the possible bit widths from binary up to some maximum value and choose the smallest one that still has acceptable accuracy; this is what we often do.
Code for how to do this can be found [here](https://github.com/thesps/keras-training/blob/qkeras/train/train_scan_models.py#L16), and the result is illustrated below (a minimal scan loop is also sketched after this list).
For binary and ternary quantization, we use the special ```binary(alpha=1.0)(x)``` and ```ternary(alpha=1.0)(x)``` quantizers.
<img src="images/quant_scan.png" width="400"/>
- Another thing you can do is to use our library for automatic quantization, [AutoQKeras](https://github.com/google/qkeras/blob/master/notebook/AutoQKeras.ipynb), to find the optimal quantization for each layer. This runs hyperparameter optimisation over quantizers/nodes/filters simultaneously. An example can be found at the end of [this notebook](https://github.com/fastmachinelearning/hls4ml-tutorial/blob/main/part6_cnns.ipynb), "Bonus exercise: Automatic quantization with AutoQKeras".
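For concreteness, a minimal sketch of the bit-width scan from the first option could look like the following. Here ```build_qmodel(bits)``` is a hypothetical helper that rebuilds the QKeras model above with ```quantized_bits(bits, 0, alpha=1)``` kernels/biases and ```quantized_relu(bits)``` activations, and the (quantized) training and test sets loaded later in this notebook are assumed to be available:
``` python
from sklearn.metrics import roc_auc_score

scan_auc = {}
for bits in [2, 4, 6, 8, 12, 16]:
    scan_model = build_qmodel(bits)  # hypothetical helper, see the text above
    scan_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    scan_model.fit(qX_train, y_train, batch_size=2048, epochs=10, validation_split=0.2, verbose=0)
    scan_auc[bits] = roc_auc_score(y_test, scan_model.predict(qX_test, batch_size=4096))
print(scan_auc)  # pick the smallest bit width whose AUC is still acceptable
```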
%% Cell type:markdown id:17da0954 tags:
## Pruning
Besides reducing the numerical precision of all the weights, biases and activations, we also want to remove neurons and synapses that do not contribute much to the network's overall decision. We do that by pruning; let's remove 50\% of the weights (sparsity=0.5):
%% Cell type:code id:195fe6ae tags:
``` python
from tensorflow_model_optimization.python.core.sparsity.keras import prune, pruning_schedule
from tensorflow_model_optimization.sparsity.keras import strip_pruning
# One training step is one gradient update, i.e. there are epochs * N_samples / batch_size steps in total.
# Pruning starts at step 6000 and the sparsity mask is updated every 10 steps, targeting 50% sparsity.
pruning_params = {"pruning_schedule": pruning_schedule.ConstantSparsity(0.5, begin_step=6000, frequency=10)}
qmodel = prune.prune_low_magnitude(qmodel, **pruning_params)
```
%% Cell type:markdown id:35b1494e tags:
## Defining the data input type
Great, we now have our model ready to be trained! There is one last important thing we have to think about and that is the *precision of the input*. In the L1T, all of the inputs are quantized. For instance, the precision used for the GT is listed [here](https://github.com/cms-l1-globaltrigger/mp7_ugt_legacy/blob/master/doc/scales_inputs_2_ugt/pdf/scales_inputs_2_ugt.pdf).
Ideally, when you train your network, you use the hardware values that the algorithm will actually receive when running inference in the trigger.
We saw, however, that in the previous exercise the inputs were all scaled to have a mean of zero and a variance of one. That means that the optimal precision for the inputs has changed, and you need to define what the new precision will be. Here we will do it by inspection and intuition, and use the same precision for all of the input features. Let's now load and scale the data and look at the input value distribution:
%% Cell type:code id:ce0355af tags:
``` python
import awkward as ak
import pickle
path = '/eos/home-t/thaarres/cms_mlatl1t_tutorial/'
X_train = ak.from_parquet(path + "/X_train.parquet").to_numpy()
X_test = ak.from_parquet(path + "/X_test.parquet").to_numpy()
y_train = ak.from_parquet(path + "/y_train.parquet").to_numpy()
y_test = ak.from_parquet(path + "/y_test.parquet").to_numpy()
# In this case the test and train data are already scaled, but this is how you would load the scaler and apply it:
# Load the scaler and its parameters and apply them to the data
scale = False
if scale:
    file_path = path + '/scaler.pkl'
    with open(file_path, 'rb') as file:
        scaler = pickle.load(file)
    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)
print(f"Training on {X_train.shape[0]} events, represented by {X_train.shape[1]} input features")
print(f"Testing on {X_test.shape[0]} events, represented by {X_test.shape[1]} input features")
```
%% Cell type:code id:94d20cfb tags:
``` python
import matplotlib.pyplot as plt
bins = 1024
plt.figure(figsize=(10, 6))
plt.hist(X_train[:1000, :], bins=bins, stacked=True, label=[f'Input {i+1}' for i in range(X_train.shape[1])])
plt.xlabel('Feature Value (standardized)')
plt.ylabel('Frequency')
plt.title('Stacked Histogram of all features')
plt.legend(loc='upper right', ncol=2)
plt.semilogy()
plt.show()
```
%% Cell type:markdown id:a275844b tags:
In this case, the values seem to be below ~50, so let's assume 6 integer bits ($2^6=64$). The number of fractional bits will define our precision and will affect the network performance. Let's assume 10 is sufficient (the smallest increment we can represent is then $2^{-10}=0.0009765625$). To make our network adapt to this input precision, we need to "treat" our training and testing set with a quantizer that maps FP32 $\rightarrow <16,6>$:
%% Cell type:code id:dc7be5f8 tags:
``` python
import numpy as np
input_quantizer = quantized_bits(bits=16, integer=6, symmetric=0, alpha=1)
qX_train = input_quantizer(X_train.astype(np.float32)).numpy()
qX_test = input_quantizer(X_test.astype(np.float32)).numpy()
# Save the quantized test data and labels to a numpy file, such that it can be used to test the firmware
np.save('qX_test.npy', qX_test)
np.save('qy_test.npy', y_test)
```
%% Cell type:code id:3de19ae6 tags:
``` python
plt.figure(figsize=(10, 6))
plt.hist(qX_train[:1000, :], bins=bins, stacked=True, label=[f'Input {i+1}' for i in range(X_train.shape[1])])
plt.xlabel('Quantized Feature Value (standardized)')
plt.ylabel('Frequency')
plt.title('Stacked Histogram of all features')
plt.legend(loc='upper right', ncol=2)
plt.semilogy()
plt.show()
```
%% Cell type:markdown id:3e0e7438 tags:
The input distribution looks similar after quantization, but we cannot really say how much performance we lose without training with different input precisions.
Phew, okay, finally time to train. For this part there are two things to note: you need to add a pruning callback, and you might need to adjust the learning rate (for example add a learning rate decay). Let's train!
%% Cell type:code id:9556e6bf tags:
``` python
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
model_checkpoint = ModelCheckpoint('model_best_checkpoint.h5', save_best_only=True, monitor='val_loss')
# This might result in returning a not fully pruned model, but that's okay!
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3)
early_stopping = EarlyStopping(monitor='val_loss', patience=5)
callbacks=[early_stopping, reduce_lr, model_checkpoint, pruning_callbacks.UpdatePruningStep()]
adam = Adam(learning_rate=0.0001)
qmodel.compile(optimizer=adam, loss=['binary_crossentropy'], metrics=['accuracy'])
qmodel.fit(qX_train, y_train, batch_size=2048, epochs=50, validation_split=0.20, shuffle=True, callbacks=callbacks, verbose=1)
qmodel = strip_pruning(qmodel)
qmodel.save('qtopo_model.h5')
```
%% Cell type:markdown id:75409ec1 tags:
Before checking and comparing the accuracy, let's look at the weights and see if they look quantized and pruned:
%% Cell type:code id:402b4267 tags:
``` python
colors = ['#7b3294','#c2a5cf','#a6dba0','#008837']
# TAKE EVERY OPPORTUNITY TO ADVERTISE COLORBLIND SAFE PLOTS :)
allWeightsByLayer = {}
for layer in qmodel.layers:
    layername = layer._name
    if len(layer.get_weights()) < 1:
        continue
    weights = layer.weights[0].numpy().flatten()
    allWeightsByLayer[layername] = weights
    print('Layer {}: % of zeros = {}'.format(layername, np.sum(weights == 0) / np.size(weights)))
labelsW = []
histosW = []
for key in reversed(sorted(allWeightsByLayer.keys())):
    labelsW.append(key)
    histosW.append(allWeightsByLayer[key])
fig = plt.figure()
ax = fig.add_subplot()
plt.semilogy()
plt.legend(loc='upper left',fontsize=15,frameon=False)
bins = np.linspace(-1, 1, 1024)
ax.hist(histosW,bins,histtype='stepfilled',stacked=True,label=labelsW,color=colors)#, edgecolor='black')
ax.legend(frameon=False,loc='upper left')
axis = plt.gca()
plt.ylabel('Number of Weights')
plt.xlabel('Weights')
plt.show()
```
%% Cell type:markdown id:cdd44876 tags:
This looks as expected! Now, let's compare the performance to that of the floating-point model:
%% Cell type:code id:6fdf547a tags:
``` python
y_pred = baseline_model.predict(X_test, batch_size = 4096)
qy_pred = qmodel.predict(qX_test, batch_size = 4096)
```
%% Cell type:code id:49ea6534 tags:
``` python
from sklearn.metrics import roc_curve, roc_auc_score
assert(len(y_test) == len(y_pred) == len(qy_pred)), "Inconsistent predicted and true!"
fpr, tpr, thr = roc_curve(y_test, y_pred, pos_label=None, sample_weight=None, drop_intermediate=True)
roc_auc = roc_auc_score(y_test, y_pred)
qfpr, qtpr, qthr = roc_curve(y_test, qy_pred, pos_label=None, sample_weight=None, drop_intermediate=True)
qroc_auc = roc_auc_score(y_test, qy_pred)
```
%% Cell type:code id:41a6d24c tags:
``` python
# Let's also convert from FPR to L1 rate:
def totalMinBiasRate():
    LHCfreq = 11245.6  # LHC revolution frequency in Hz
    nCollBunch = 2544  # number of colliding bunches
    return LHCfreq * nCollBunch / 1e3  # total minimum-bias rate in kHz (~28600 kHz)

fpr *= totalMinBiasRate()
qfpr *= totalMinBiasRate()
```
%% Cell type:code id:a3d20287 tags:
``` python
# Lets plot it!
f, ax = plt.subplots(figsize=(8,6))
# plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')
ax.tick_params(axis='both', which='major', labelsize=14)
ax.tick_params(axis='both', which='minor', labelsize=14)
ax.set_xlim(0,100)
ax.plot(fpr, tpr, color='#7b3294', lw=2, ls='dashed', label=f'Baseline (AUC = {roc_auc:.5f})')
ax.plot(qfpr, qtpr, color='#008837', lw=2, label=f'Quantized+Pruned (AUC = {qroc_auc:.5f})')
ax.set_xlabel('L1 Rate (kHz)')
ax.set_ylabel('Signal efficiency')
ax.legend(loc="lower right")
ax.grid(True)
plt.show()
```
%% Cell type:markdown id:2ba99f2a tags:
So it seems that despite having reduced the numerical precision of the model and the input, as well as removing 50% of the model weights, we're doing pretty well! This can be tuned to get even better, by carefully adjusting the input precision and the model precision, especially increasing the precision of the logit layer.
# Generating firmware with
<img src="images/hls4ml_logo.png" width="400"/>
Time to translate this model into HLS (which we will integrate into the emulator) and use it to generate the VHDL to be integrated into the trigger firmware.
hls4ml talks seamlessly to QKeras, making our job much easier, but there is still some work for us to do to make sure we get good hardware model accuracy. Let's start!
There are a few things I already know in advance and would like my model to include:
- Be executed fully parallel (unrolled) to reach the lowest possible latency. We set ReuseFactor=1 and Strategy=Latency
- The correct input precision
- The correct model output (that's something you have to figure out yourself!)
- Use "correct" precisions to make sure the hardware model performs the same as the software one. QKeras handles weights/biases and activation functions for us, but there are a few other parameters that need to be set by hand
For the final point, have a look at the following diagram:
<img src="images/hls4ml_precisions.png" width="400"/>
Whereas the $weight$ and $bias$ are set to their optimal values from the QKeras model, the accumulator $accum$ and the $result$ are set to some default value that might not be optimal for a given model and might need tuning. Let's do a first attempt and compare the ROC curves:
%% Cell type:code id:9075991f tags:
``` python
import hls4ml
def print_dict(d, indent=0):
    for key, value in d.items():
        print(' ' * indent + str(key), end='')
        if isinstance(value, dict):
            print()
            print_dict(value, indent + 1)
        else:
            print(':' + ' ' * (20 - len(key) - 2 * indent) + str(value))
config = hls4ml.utils.config_from_keras_model(qmodel, granularity='name')
config["Model"]["Strategy"] = "Latency"
config["Model"]["ReuseFactor"] = 1
inputPrecision = "ap_fixed<16,7,AP_RND,AP_SAT>" #Adding one bit for the sign, different definitions QKeras/Vivado
for layer in qmodel.layers:
    if layer.__class__.__name__ in ["InputLayer"]:
        config["LayerName"][layer.name]["Precision"] = inputPrecision
config["LayerName"]["output"]["Precision"]["result"] = "ap_fixed<13,2,AP_RND,AP_SAT>"
# If the logit layer is a "normal" Keras layer and has not been quantized during the training,
# we need to be careful setting its precision. This can be done in the following way:
# config["LayerName"]["logits"]["Precision"]["weight"] = "ap_fixed<13,2,AP_RND,AP_SAT>"
# config["LayerName"]["logits"]["Precision"]["bias"] = "ap_fixed<13,2,AP_RND,AP_SAT>"
# config["LayerName"]["logits"]["Precision"]["accum"] = "ap_fixed<13,2,AP_RND,AP_SAT>"
# config["LayerName"]["logits"]["Precision"]["result"] = "ap_fixed<13,2,AP_RND,AP_SAT>"
print("-----------------------------------")
print_dict(config)
print("-----------------------------------")
hls_model = hls4ml.converters.convert_from_keras_model(qmodel,
hls_config=config,
io_type='io_parallel', #other option is io_stream
output_dir='L1TMLDemo_v1',
project_name='L1TMLDemo_v1',
part='xcu250-figd2104-2L-e', # Target FPGA; ideally you would use the VU9P or VU13P that we use in the L1T, but they are not installed at lxplus, and this one is close enough for this tutorial
clock_period=2.5, # Target frequency 1/2.5ns = 400 MHz
input_data_tb='qX_test.npy', # For co-simulation
output_data_tb='qy_test.npy',# For co-simulation
)
hls_model.compile()
```
%% Cell type:markdown id:9e00bf3f tags:
First, what does our newly created model look like?
%% Cell type:code id:fc990262 tags:
``` python
hls4ml.utils.plot_model(hls_model, show_shapes=True, show_precision=True, to_file=None)
```
%% Cell type:markdown id:011b95c4 tags:
Here you can see that the precision is what we set it to be in QKeras, as well as what we set manually in the config. One thing to note is the different conventions used in QKeras and in ap_fixed: in ```ap_fixed<W,I>``` the integer width $I$ includes the sign bit, whereas ```quantized_bits(bits, integer)``` counts the sign bit separately, so
- ```quantized_bits(8,0) -> ap_fixed<8,1>```
- ```quantized_relu(8,0) -> ap_ufixed<8,0>```
Also you can see that the default value for result/accum is set to $<16,6>$. This can also be tuned to more optimal values.
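If the defaults turn out to limit agreement with the QKeras model (or to be wastefully wide), they can be overridden per layer in the hls4ml config and the model re-converted, in the same way as the commented-out logits example in the code above. A hypothetical tweak could look like this:
``` python
# Illustrative only: widen the accumulator and result of the first QDense layer,
# then re-run hls4ml.converters.convert_from_keras_model with the updated config.
config["LayerName"]["qd1"]["Precision"]["accum"] = "ap_fixed<24,8>"
config["LayerName"]["qd1"]["Precision"]["result"] = "ap_fixed<18,8>"
```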
%% Cell type:code id:660f657e tags:
``` python
# Let's also run predict on the C++ implementation of our model and make sure it's ~the same as for the QKeras model.
# The C++ emulation is much slower than Keras predict, but here we run the full test set so we can compare ROC curves below.
y_hls = hls_model.predict(np.ascontiguousarray(qX_test))
print(f"Truth labels:\n {y_test[17:27]}")
print(f"Qkeras prediction:\n {qy_pred[17:27]}")
print(f"HLS prediction:\n {y_hls[17:27]}")
```
%% Cell type:code id:ac7480a6 tags:
``` python
# Lets plot it!
hlsfpr, hlstpr, hlsthr = roc_curve(y_test, y_hls, pos_label=None, sample_weight=None, drop_intermediate=True)
hlsfpr *= totalMinBiasRate()
hlsroc_auc = roc_auc_score(y_test, y_hls)
f, ax = plt.subplots(figsize=(8,6))
# plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')
ax.tick_params(axis='both', which='major', labelsize=14)
ax.tick_params(axis='both', which='minor', labelsize=14)
ax.set_xlim(0,100)
ax.plot(fpr, tpr, color='#7b3294', lw=2, ls='dashed', label=f'Baseline (AUC = {roc_auc:.5f})')
ax.plot(qfpr, qtpr, color='#008837', lw=2, label=f'Quantized+Pruned (AUC = {qroc_auc:.5f})')
ax.plot(hlsfpr, hlstpr, color='#a6dba0', lw=2, ls='dotted', label=f'HLS Quantized+Pruned (AUC = {hlsroc_auc:.5f})')
ax.set_xlabel('L1 Rate (kHz)')
ax.set_ylabel('Signal efficiency')
ax.legend(loc="lower right")
ax.grid(True)
plt.show()
```
%% Cell type:markdown id:00985383 tags:
Oh! That was easier than expected. If you see the accuracies differing significantly, it's a good idea to look into the accumulator and result precisions. Tools like Trace and Profiling, which you can learn about in the [official hls4ml tutorial](https://github.com/fastmachinelearning/hls4ml-tutorial/blob/main/part2_advanced_config.ipynb), can also be helpful! In this case it doesn't seem to be necessary. Now let's build it! Let's run C synthesis (C++ to register-transfer level), Vivado logic synthesis (gate-level representation) and co-simulation (send test vectors and do an exhaustive functional test of the implemented logic).
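If you do need to dig deeper, the tracing and profiling utilities are typically invoked along the following lines (a sketch based on the hls4ml tutorial; check the exact calls against your hls4ml version, and note that tracing has to be enabled in the layer config before converting and compiling):
``` python
# Numerical profiling: compare weight/activation distributions with the chosen fixed-point ranges
hls4ml.model.profiling.numerical(model=qmodel, hls_model=hls_model, X=qX_test[:1000])

# Tracing: per-layer outputs of the HLS model and the Keras model on the same inputs
# (requires config["LayerName"][<layer>]["Trace"] = True before conversion)
hls4ml_pred, hls4ml_trace = hls_model.trace(np.ascontiguousarray(qX_test[:1000]))
keras_trace = hls4ml.model.profiling.get_ymodel_keras(qmodel, qX_test[:1000])
```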
%% Cell type:code id:3d73b5aa tags:
``` python
print("Running synthesis!")
report = hls_model.build(csim=False, synth=True, vsynth=True, cosim=False)
```
%% Cell type:markdown id:3893c6cd tags:
Now let's look at the reports! The latency can be extracted from the C synthesis report and validated from the co-simulation report (where actual data is sent through the logic). The resource consumption can be extracted from the implementation report (Vivado logic synthesis) and is more accurate than what is quoted in the C synthesis report:
%% Cell type:code id:dbc8b9f2 tags:
``` python
print("\nC synthesis report (latency estimate):")
print_dict(report["CSynthesisReport"])
#print_dict(report["CosimReport"]) Not working due to missing libc header sys/cdefs.h :(?
print("\nVivado synthesis report (resource estimates):")
print_dict(report["VivadoSynthReport"])
```
%% Cell type:markdown id:72c1723a tags:
A latency of $2.5\cdot16=40$ ns, that is not bad! Also, the network is using very little resources: 8k out of 1728k LUTs, 15 out of 12k DSPs. This is <1% of the total available resources. We have a set of HLS files that will be integrated into the CMSSW emulator (```L1TMLDemo_v1/firmware/```) and VHDL that will be integrated into the mGT firmware (```L1TMLDemo_v1/myproject_prj/solution1/impl/vhdl/```). That's next!