From ddcce7225cf63710a8560fe7d060cac3c08b46f8 Mon Sep 17 00:00:00 2001
From: Leon Teichroeb <leon.teichroeb@cern.ch>
Date: Thu, 12 Dec 2024 18:06:46 +0100
Subject: [PATCH] Cleaned up generate_ref_data.py

---
 utils/generate_ref_data.py | 214 ++++++++++---------------------------
 1 file changed, 58 insertions(+), 156 deletions(-)

diff --git a/utils/generate_ref_data.py b/utils/generate_ref_data.py
index d64b45f..f5db0a7 100644
--- a/utils/generate_ref_data.py
+++ b/utils/generate_ref_data.py
@@ -1,164 +1,66 @@
-import  itertools
-import  yaml
-import  numpy as np
-from 	steammaterials.STEAM_materials import STEAM_materials
-from    datetime import datetime
-from    pathlib import Path
-import  os
-import  pandas as pd
+#!/usr/bin/env python3
+import itertools
+import yaml
+import numpy as np
+from   steammaterials.STEAM_materials import STEAM_materials
+from   datetime import datetime
+from   pathlib import Path
+import os
+import sys
+import pandas as pd
 
 
-################ FILL THIS IN #################
-csv_file = pd.read_csv(Path('../all_material_functions.csv'))
+def create_all_ref_files(output_path=None):
+    repository_root = Path(os.path.dirname(__file__)).parent
+    if output_path is None:
+        output_path = repository_root / 'tests' / 'ref_data'
+    csv_file = pd.read_csv(repository_root / 'all_material_functions.csv')
+   
+    errors = 0
+    for i, row in csv_file.iterrows():
+        if row['make_ref']:
+            try:
+                create_ref_file(row, output_path)
+            except Exception as e:
+                print(f'Failed to generate reference file for {row["ref_name"]}')
+                print(f'Failed with error: {e}')
+                errors += 1
+    if errors:
+        print(f'!!! During generation, {errors} errors occurred !!!')
 
-for i in range(len(csv_file)):
-    if csv_file.iloc[i]['make_ref']:
-        input_parameters = int(csv_file.iloc[i]['input_parameters'])
-        gen_by = csv_file.iloc[i]['ref_generated_by']
-        ref_name = csv_file.iloc[i]['ref_name']
 
+def create_ref_file(row, output_path):
+    """
+    Creates a file to act as reference for the expected output of the given material function.
+    Expects a dictionary-like `row` containing the number of 'input_parameters', the material name
+    'ref_name', author 'ref_generated_by' and inputs with 'input{i}'.
+    """
+    parameter_count = int(row['input_parameters'])
+    ref_name = row['ref_name']
+    if parameter_count < 1:
+        raise ValueError('Incorrect number of input parameters for {ref_name}')
 
-        # you = 'Tim Mulder'
-        # # material_function = 'CFUN_rhoHast_v2'
-        # # input_parameters = 1
-        # material_function = 'CFUN_rhoCu_NIST_v2'
-        # input_parameters = 3
-        ###############################################
+    inputs = [eval(row[f'input{i+1}']) for i in range(parameter_count)]
+    input_mat = np.array(list(itertools.product(*inputs))).T
+    out = STEAM_materials(ref_name, input_mat.shape[0], input_mat.shape[1]).evaluate(input_mat)
 
-        if input_parameters == 1:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            numpy2d = input1.reshape((1, len(input1)))
-            input_list = np.array([[i] for i in input1])
-        elif input_parameters == 2:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input_list = np.array([i for i in itertools.product(input1, input2)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1]))
-        elif input_parameters == 3:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2]))
-        elif input_parameters == 4:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3]))
-        elif input_parameters == 5:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4]))
-        elif input_parameters == 6:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input6 = eval(csv_file.iloc[i]['input6'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5, input6)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4], input_list[:, 5]))
-        elif input_parameters == 7:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input6 = eval(csv_file.iloc[i]['input6'])
-            input7 = eval(csv_file.iloc[i]['input7'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5, input6, input7)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4], input_list[:, 5], input_list[:, 6]))
-        elif input_parameters == 8:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input6 = eval(csv_file.iloc[i]['input6'])
-            input7 = eval(csv_file.iloc[i]['input7'])
-            input8 = eval(csv_file.iloc[i]['input8'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5, input6, input7, input8)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4], input_list[:, 5], input_list[:, 6], input_list[:, 7]))
-        elif input_parameters == 9:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input6 = eval(csv_file.iloc[i]['input6'])
-            input7 = eval(csv_file.iloc[i]['input7'])
-            input8 = eval(csv_file.iloc[i]['input8'])
-            input9 = eval(csv_file.iloc[i]['input9'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5, input6, input7, input8, input9)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4], input_list[:, 5], input_list[:, 6], input_list[:, 7], input_list[:, 8]))
-        elif input_parameters == 10:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input6 = eval(csv_file.iloc[i]['input6'])
-            input7 = eval(csv_file.iloc[i]['input7'])
-            input8 = eval(csv_file.iloc[i]['input8'])
-            input9 = eval(csv_file.iloc[i]['input9'])
-            input10 = eval(csv_file.iloc[i]['input10'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5, input6, input7, input8, input9, input10)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4], input_list[:, 5], input_list[:, 6], input_list[:, 7], input_list[:, 8], input_list[:, 9]))
-        elif input_parameters == 11:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input6 = eval(csv_file.iloc[i]['input6'])
-            input7 = eval(csv_file.iloc[i]['input7'])
-            input8 = eval(csv_file.iloc[i]['input8'])
-            input9 = eval(csv_file.iloc[i]['input9'])
-            input10 = eval(csv_file.iloc[i]['input10'])
-            input11 = eval(csv_file.iloc[i]['input11'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4], input_list[:, 5], input_list[:, 6], input_list[:, 7], input_list[:, 8], input_list[:, 9], input_list[:, 10]))
-        elif input_parameters == 12:
-            input1 = eval(csv_file.iloc[i]['input1'])
-            input2 = eval(csv_file.iloc[i]['input2'])
-            input3 = eval(csv_file.iloc[i]['input3'])
-            input4 = eval(csv_file.iloc[i]['input4'])
-            input5 = eval(csv_file.iloc[i]['input5'])
-            input6 = eval(csv_file.iloc[i]['input6'])
-            input7 = eval(csv_file.iloc[i]['input7'])
-            input8 = eval(csv_file.iloc[i]['input8'])
-            input9 = eval(csv_file.iloc[i]['input9'])
-            input10 = eval(csv_file.iloc[i]['input10'])
-            input11 = eval(csv_file.iloc[i]['input11'])
-            input12 = eval(csv_file.iloc[i]['input12'])
-            input_list = np.array([i for i in itertools.product(input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11, input12)])
-            numpy2d = np.vstack((input_list[:, 0], input_list[:, 1], input_list[:, 2], input_list[:, 3], input_list[:, 4], input_list[:, 5], input_list[:, 6], input_list[:, 7], input_list[:, 8], input_list[:, 9], input_list[:, 10], input_list[:, 11]))
-        else:
-            print('Wrong amount of input parameters')
+    dictionary_output = {
+        'author': row['ref_generated_by'],
+        'date': datetime.today().strftime('%d.%m.%Y'),
+        'material_function': ref_name,
+        'input': input_mat.T.tolist(),
+        'output': [i.item() for i in out],
+        'input_parameters': parameter_count,
+    }
+    try:
+        with (output_path / f'{ref_name}.yaml').open("x") as f:
+            yaml.dump(dictionary_output, f, default_flow_style=False)
+    except FileExistsError:
+        print(f'Reference file for {ref_name} already exists. You cannot overwrite it.')
 
-        ###############################################
-        if input_parameters >= 1:
-            out = STEAM_materials(ref_name, numpy2d.shape[0], numpy2d.shape[1]).evaluate(numpy2d)
-        else:
-            print('Wrong amount of input parameters')
 
-        ############### Making the yaml ################
-        dictionary_output = {
-            'author': gen_by,
-            'date': datetime.today().strftime('%d.%m.%Y'),
-            'material_function': ref_name,
-            'input': [[i.item() for i in k] for k in input_list],
-            'output': [i.item() for i in out],
-            'input_parameters': input_parameters,
-        }
-        try:
-            with Path(os.path.dirname(os.path.dirname(__file__)) + os.sep + 'tests' + os.sep + 'ref_data' + os.sep + ref_name+".yaml").open("x") as f:
-                yaml.dump(dictionary_output, f, default_flow_style=False)
-        except:
-            print('Reference file for ' + ref_name + ' already excists. You cannot overwrite it.')
+if __name__ == '__main__':
+    if len(sys.argv) > 1:
+        create_all_ref_files(Path(sys.argv[1]))
+    else:
+        create_all_ref_files()
-- 
GitLab