diff --git a/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py b/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py
index 608b8f75e0e70b5b03585fcd9b88ae91fa9da861..08e71a9ed649f9f2efc489628440582be1cae769 100644
--- a/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py
+++ b/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py
@@ -10,6 +10,7 @@ from nxcals.api.extraction.data.builders_expanded import DataQuery as DataQueryE
 from pyspark.sql.functions import col
 from pyspark.sql.utils import IllegalArgumentException
 
+from nxcals.api.utils.extraction.array_utils import ArrayUtils
 from . import (
     PySparkIntegrationTest,
     metadata_utils,
@@ -17,6 +18,7 @@ from . import (
     log,
 )
 from .. import MONITORING_SYSTEM_NAME, MONITORING_DEVICE_KEY
+from pyspark.sql.types import ArrayType, IntegerType, FloatType
 
 
 def no_data_losses_day():
@@ -705,3 +707,31 @@ class ShouldExecuteAsExpanded(PySparkIntegrationTest):
         )
         self.assertTrue(isinstance(list_of_dses, list))
         self.assertTrue(len(list_of_dses) > 1)
+
+
+class ShouldExtractArrayColumns(PySparkIntegrationTest):
+    def runTest(self):
+        df = DataQuery.getForEntities(
+            spark_session,
+            system=MONITORING_SYSTEM_NAME,
+            start_time=no_data_losses_day(),
+            end_time=end_of_no_data_losses_day(),
+            entity_queries=[
+                EntityQuery(
+                    {MONITORING_DEVICE_KEY: "NXCALS_MONITORING_DEV6"},
+                ),
+            ],
+        )
+
+        field_name = "intArrayField"
+        df2 = ArrayUtils.reshape(df, [field_name])
+        self.assertEqual(df2.schema[field_name].dataType, ArrayType(IntegerType()))
+
+        field_name = "floatArray2DField2"
+        df3 = ArrayUtils.reshape(df, [field_name])
+        self.assertEqual(
+            df3.schema[field_name].dataType, ArrayType(ArrayType(FloatType()))
+        )
+
+        df4 = ArrayUtils.reshape(df)
+        self.assertTrue(len(ArrayUtils._extract_array_fields(df4.schema)) == 0)
\ No newline at end of file
diff --git a/integration-tests/src/python-integration-test/python/nxcals/integrationtests/pytimber_tests/test_pytimber.py b/integration-tests/src/python-integration-test/python/nxcals/integrationtests/pytimber_tests/test_pytimber.py
index adb3b4324da487c2350bc792e390b0e7ce7af129..fcc236cc1975deefb9e1aa762130a2b634dad88c 100644
--- a/integration-tests/src/python-integration-test/python/nxcals/integrationtests/pytimber_tests/test_pytimber.py
+++ b/integration-tests/src/python-integration-test/python/nxcals/integrationtests/pytimber_tests/test_pytimber.py
@@ -283,7 +283,7 @@ class TestAligned:
         result = ldb.getAligned(pattern_or_list, START_TIME, END_TIME)
         (_, master_data_values) = ldb.getVariable(master_var_name, START_TIME, END_TIME)
 
-        assert np.array_equal(result[master_var_name], master_data_values)
+        assert all(np.array_equal(arr1, arr2) for arr1, arr2 in zip(result[master_var_name], master_data_values))
 
     @pytest.mark.skip(reason="no way of currently testing this")
     def test_on_real_data(self, ldb: LoggingDB) -> None:
@@ -779,7 +779,7 @@ class TestExtractionMethods:
             variable_name, START_TIME, END_TIME, unixtime=False
         )[variable_name]
         assert np.array_equal(timestamps, timestamps_1)
-        assert np.array_equal(values, values_1)
+        assert all(np.array_equal(arr1, arr2) for arr1, arr2 in zip(values, values_1))
 
 
 class TestTimestamps:
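Note on the two test_pytimber.py asserts above: for vector variables every extracted sample is itself an array, and samples at different timestamps can have different lengths. A single np.array_equal over the whole column first coerces both sides into one ndarray; on recent NumPy that coercion fails internally for ragged input and the call simply returns False, even when every per-sample array matches. Hence the switch to a sample-by-sample comparison. A minimal standalone sketch of that failure mode (made-up data, not part of the patch):

import numpy as np

# Two extractions of the same vector variable: one array per timestamp,
# with lengths that differ between timestamps (ragged data).
result = [np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])]
expected = [np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])]

# A single call reports False: np.array_equal cannot coerce the ragged
# lists into one ndarray, so it gives up rather than comparing contents.
print(np.array_equal(result, expected))  # False

# Comparing sample by sample, as the updated tests do, succeeds.
print(all(np.array_equal(a, b) for a, b in zip(result, expected)))  # True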
diff --git a/python/extraction-api-python3/nxcals/api/utils/__init__.py b/python/extraction-api-python3/nxcals/api/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/nxcals/api/utils/extraction/__init__.py b/python/extraction-api-python3/nxcals/api/utils/extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/nxcals/api/utils/extraction/array_utils.py b/python/extraction-api-python3/nxcals/api/utils/extraction/array_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..08789aac7d3f332460e3a6edbd91c085865dcec8
--- /dev/null
+++ b/python/extraction-api-python3/nxcals/api/utils/extraction/array_utils.py
@@ -0,0 +1,150 @@
+from typing import List, Optional
+
+import numpy as np
+import pyspark.sql as sp
+
+from pyspark.sql.functions import col, udf
+from pyspark.sql.types import ArrayType, StructType, DataType
+
+ARRAY_ELEMENTS_FIELD_NAME = "elements"
+ARRAY_DIMENSIONS_FIELD_NAME = "dimensions"
+
+
+class ArrayUtils:
+    @staticmethod
+    def _extract_array_fields(schema: StructType) -> List[str]:
+        """
+        Extracts field names that contain structured array data (with 'elements' and 'dimensions').
+
+        :param schema: Spark StructType schema
+        :return: List of field names that represent arrays
+        """
+        return [
+            field.name
+            for field in schema.fields
+            if isinstance(field.dataType, StructType)
+            and {ARRAY_ELEMENTS_FIELD_NAME, ARRAY_DIMENSIONS_FIELD_NAME}.issubset(
+                set(field.dataType.names)
+            )
+        ]
+
+    @staticmethod
+    def _validate_array_fields(schema: StructType, value_fields: List[str]) -> None:
+        """
+        Validates that the given columns exist in the schema and have the required structure.
+
+        :param schema: Spark StructType schema
+        :param value_fields: Column names to validate
+        :raises ValueError: If a column is missing or does not conform to the expected structure
+        """
+        for value_field in value_fields:
+            try:
+                field = schema[value_field]
+            except KeyError:
+                raise ValueError(f"Field '{value_field}' does not exist in the schema")
+
+            if not isinstance(field.dataType, StructType):
+                raise ValueError(f"Field '{value_field}' is not a struct type")
+
+            field_names = set(field.dataType.names)
+            if field_names != {"elements", "dimensions"}:
+                raise ValueError(
+                    f"Field '{value_field}' must contain both 'elements' and 'dimensions' fields"
+                )
+
+    @staticmethod
+    def _reshape_array(elements, dimensions):
+        """
+        Reshapes a flat list into a multidimensional array, returned as nested lists.
+
+        :param elements: The flattened array elements.
+        :param dimensions: The shape of the desired multidimensional array.
+        :return: The reshaped array converted back to a nested list.
+        """
+        if len(dimensions) == 1:
+            return elements
+
+        np_array = np.array(elements)
+
+        total_elements = np.prod(dimensions)
+        if len(elements) != total_elements:
+            raise ValueError(
+                f"Cannot reshape array: {len(elements)} elements cannot fit into shape {dimensions}."
+            )
+
+        reshaped_array = np_array.reshape(dimensions).tolist()
+        return reshaped_array
+
+    @staticmethod
+    def _get_nested_array_type(depth: int, base_type: DataType) -> DataType:
+        """
+        Recursively wraps the given base type in ArrayType, once per level.
+
+        :param depth: Number of ArrayType levels to add.
+        :param base_type: The innermost element type.
+        :return: Nested ArrayType.
+ """ + array_type = base_type + for _ in range(depth): + array_type = ArrayType(array_type) + return array_type + + @staticmethod + def reshape( + df: sp.DataFrame, array_columns: list[str] | None = None + ) -> sp.DataFrame: + """ + Extracts and reshape array columns for easier processing. + + :param df: Input Spark DataFrame + :param array_columns: List of column names containing structured array data + :return: Transformed DataFrame with extracted 'elements' and 'dimensions' columns + """ + schema = df.schema + + if not array_columns: + array_columns = ArrayUtils._extract_array_fields(schema) + else: + ArrayUtils._validate_array_fields(schema, array_columns) + + first_row = df.first() + for column in array_columns: + base_array_type = ( + schema[column].dataType[ARRAY_ELEMENTS_FIELD_NAME].dataType + ) + + first_row_dimension = first_row[column][ARRAY_DIMENSIONS_FIELD_NAME] + if first_row_dimension: + dimension_length = len(first_row_dimension) + + if dimension_length == 1: + reshape_array_udf = udf( + lambda elements, dims: ArrayUtils._reshape_array( + elements, dims + ), + base_array_type, + ) + else: + nested_array_type = ArrayUtils._get_nested_array_type( + dimension_length - 1, base_array_type + ) + reshape_array_udf = udf( + lambda elements, dims: ArrayUtils._reshape_array( + elements, dims + ), + nested_array_type, + ) + + df = df.withColumn( + column, + reshape_array_udf( + col(f"{column}.{ARRAY_ELEMENTS_FIELD_NAME}"), + col(f"{column}.{ARRAY_DIMENSIONS_FIELD_NAME}"), + ), + ) + + return df diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/__init__.py b/python/extraction-api-python3/tests/nxcals/api/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/extraction/__init__.py b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..abd9a1c81be8be7d907144d2445c2048cfcddb24 --- /dev/null +++ b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py @@ -0,0 +1,104 @@ +from unittest import TestCase +from pyspark.sql.types import IntegerType, ArrayType, StructType, StructField +from nxcals.api.utils.extraction.array_utils import ( + ArrayUtils, + ARRAY_ELEMENTS_FIELD_NAME, + ARRAY_DIMENSIONS_FIELD_NAME, +) + + +class should_build_query(TestCase): + def test_extract_array_fields(self): + schema = StructType( + [ + StructField("normal_col", IntegerType(), True), + StructField( + "array_field", + StructType( + [ + StructField( + ARRAY_ELEMENTS_FIELD_NAME, + ArrayType(IntegerType()), + True, + ), + StructField( + ARRAY_DIMENSIONS_FIELD_NAME, + ArrayType(IntegerType()), + True, + ), + ] + ), + True, + ), + ] + ) + + extracted_fields = ArrayUtils._extract_array_fields(schema) + self.assertEqual(extracted_fields, ["array_field"]) + + def test_get_nested_array_type(self): + nested_type = ArrayUtils._get_nested_array_type(3, IntegerType()) + self.assertIsInstance(nested_type, ArrayType) + self.assertIsInstance(nested_type.elementType, ArrayType) + self.assertIsInstance(nested_type.elementType.elementType, ArrayType) + 
diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/__init__.py b/python/extraction-api-python3/tests/nxcals/api/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/extraction/__init__.py b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd9a1c81be8be7d907144d2445c2048cfcddb24
--- /dev/null
+++ b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py
@@ -0,0 +1,104 @@
+from unittest import TestCase
+from pyspark.sql.types import IntegerType, ArrayType, StructType, StructField
+from nxcals.api.utils.extraction.array_utils import (
+    ArrayUtils,
+    ARRAY_ELEMENTS_FIELD_NAME,
+    ARRAY_DIMENSIONS_FIELD_NAME,
+)
+
+
+class ShouldHandleArrayColumns(TestCase):
+    def test_extract_array_fields(self):
+        schema = StructType(
+            [
+                StructField("normal_col", IntegerType(), True),
+                StructField(
+                    "array_field",
+                    StructType(
+                        [
+                            StructField(
+                                ARRAY_ELEMENTS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                            StructField(
+                                ARRAY_DIMENSIONS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                        ]
+                    ),
+                    True,
+                ),
+            ]
+        )
+
+        extracted_fields = ArrayUtils._extract_array_fields(schema)
+        self.assertEqual(extracted_fields, ["array_field"])
+
+    def test_get_nested_array_type(self):
+        nested_type = ArrayUtils._get_nested_array_type(3, IntegerType())
+        self.assertIsInstance(nested_type, ArrayType)
+        self.assertIsInstance(nested_type.elementType, ArrayType)
+        self.assertIsInstance(nested_type.elementType.elementType, ArrayType)
+        self.assertIsInstance(
+            nested_type.elementType.elementType.elementType, IntegerType
+        )
+
+    def test_validate_array_fields_valid(self):
+        schema = StructType(
+            [
+                StructField(
+                    "valid_array",
+                    StructType(
+                        [
+                            StructField(
+                                ARRAY_ELEMENTS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                            StructField(
+                                ARRAY_DIMENSIONS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                        ]
+                    ),
+                    True,
+                )
+            ]
+        )
+
+        ArrayUtils._validate_array_fields(schema, ["valid_array"])
+
+    def test_validate_array_fields_invalid(self):
+        schema = StructType(
+            [
+                StructField(
+                    "invalid_array",
+                    StructType(
+                        [
+                            StructField(
+                                ARRAY_ELEMENTS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            )
+                        ]
+                    ),
+                    True,
+                )
+            ]
+        )
+
+        with self.assertRaisesRegex(
+            ValueError, "must contain both 'elements' and 'dimensions'"
+        ):
+            ArrayUtils._validate_array_fields(schema, ["invalid_array"])
+
+    def test_reshape_array_valid(self):
+        reshaped = ArrayUtils._reshape_array([1, 2, 3, 4], [2, 2])
+        self.assertEqual(reshaped, [[1, 2], [3, 4]])
+
+    def test_reshape_array_1D(self):
+        reshaped = ArrayUtils._reshape_array([1, 2, 3, 4], [4])
+        self.assertEqual(reshaped, [1, 2, 3, 4])
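Finally, a minimal end-to-end usage sketch of the new utility (a local Spark session and made-up device data; the inner field names mirror the 'elements'/'dimensions' constants the module expects):

from pyspark.sql import Row, SparkSession

from nxcals.api.utils.extraction.array_utils import ArrayUtils

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A toy frame mimicking the NXCALS array layout: a struct holding the
# flattened 'elements' plus the 'dimensions' of the original shape.
df = spark.createDataFrame(
    [Row(device="DEV1", matrix=Row(elements=[1, 2, 3, 4, 5, 6], dimensions=[2, 3]))]
)

reshaped = ArrayUtils.reshape(df, ["matrix"])
reshaped.show(truncate=False)
# +------+----------------------+
# |device|matrix                |
# +------+----------------------+
# |DEV1  |[[1, 2, 3], [4, 5, 6]]|
# +------+----------------------+

The column is replaced in place: after the call, 'matrix' holds the nested array directly, which is exactly what ShouldExtractArrayColumns asserts on the extracted DataFrame's schema.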