diff --git a/etx4velo/pipeline/utils/modelutils/basemodel.py b/etx4velo/pipeline/utils/modelutils/basemodel.py index d3da99944bdad27cc36493e36e2244c58c0664b3..43ae085d00a0f0391e9787d600687882fd9c04cf 100644 --- a/etx4velo/pipeline/utils/modelutils/basemodel.py +++ b/etx4velo/pipeline/utils/modelutils/basemodel.py @@ -19,7 +19,6 @@ from torch_geometric.loader import DataLoader from utils.commonutils.cfeatures import get_input_features from utils.loaderutils.dataiterator import LazyDatasetBase -from utils.modelutils.export import convert_model_to_fp16 def check_and_discard(s: typing.Set[typing.Any], element: typing.Any) -> bool: @@ -640,6 +639,8 @@ class ModelBase(LightningModule): print("Model was exported to", os.path.abspath(outpath)) if use_fp16: + from utils.modelutils.export import convert_model_to_fp16 + convert_model_to_fp16(outpath) diff --git a/etx4velo/pipeline/utils/modelutils/export.py b/etx4velo/pipeline/utils/modelutils/export.py index c90d34212c980d75080d7ad889ba14e9d770a678..75dbc63fb2f456529e241e882765c794df0ee1c5 100644 --- a/etx4velo/pipeline/utils/modelutils/export.py +++ b/etx4velo/pipeline/utils/modelutils/export.py @@ -53,6 +53,7 @@ def convert_model_to_fp16(inpath: str, outpath: str | None = None) -> None: model_fp16 = float16.convert_float_to_float16(model) onnx.save(model_fp16, outpath) + class TRTScatterAddOp(torch.autograd.Function): """A fake scatter add operator for ONNX export, used with a custom TensorRT plugin that implements the scatter add operation.