diff --git a/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py b/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py index 608b8f75e0e70b5b03585fcd9b88ae91fa9da861..74e571ed5310e681533d16c0c938185692f9d7b1 100644 --- a/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py +++ b/integration-tests/src/python-integration-test/python/nxcals/integrationtests/data_access/extraction_test.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta import pkg_resources + from nxcals.api.extraction.data.builders import ( DataQuery, EntityQuery, @@ -10,6 +11,7 @@ from nxcals.api.extraction.data.builders_expanded import DataQuery as DataQueryE from pyspark.sql.functions import col from pyspark.sql.utils import IllegalArgumentException +from nxcals.api.utils.extraction.array_utils import ArrayUtils from . import ( PySparkIntegrationTest, metadata_utils, @@ -17,6 +19,11 @@ from . import ( log, ) from .. 
import MONITORING_SYSTEM_NAME, MONITORING_DEVICE_KEY
+from pyspark.sql.types import ArrayType, FloatType, IntegerType
+from pyspark.sql.types import StructType, StructField
+from pyspark.sql import Row
+from nxcals.api.utils.constants import ARRAY_ELEMENTS_FIELD_NAME, ARRAY_DIMENSIONS_FIELD_NAME
+
 
 
 def no_data_losses_day():
@@ -705,3 +712,72 @@ class ShouldExecuteAsExpanded(PySparkIntegrationTest):
         )
         self.assertTrue(isinstance(list_of_dses, list))
         self.assertTrue(len(list_of_dses) > 1)
+
+
+class ShouldExtractArrayColumns(PySparkIntegrationTest):
+    def runTest(self):
+        df = DataQuery.getForEntities(
+            spark_session,
+            system=MONITORING_SYSTEM_NAME,
+            start_time=no_data_losses_day(),
+            end_time=end_of_no_data_losses_day(),
+            entity_queries=[
+                EntityQuery(
+                    {MONITORING_DEVICE_KEY: "NXCALS_MONITORING_DEV6"},
+                ),
+            ],
+        )
+
+        field_name = "intArrayField"
+        df2 = ArrayUtils.reshape(df, [field_name])
+        self.assertEqual(df2.schema[field_name].dataType, ArrayType(IntegerType()))
+
+        field_name = "floatArray2DField2"
+        df3 = ArrayUtils.reshape(df, [field_name])
+        self.assertEqual(df3.schema[field_name].dataType, ArrayType(ArrayType(FloatType())))
+
+        df4 = ArrayUtils.reshape(df)
+        self.assertTrue(len(ArrayUtils._extract_array_fields(df4.schema)) == 0)
+
+class ShouldUdfRaiseExceptionOnInvalidDimensions(PySparkIntegrationTest):
+
+    def test_changing_dimensions(self):
+        schema = StructType(
+            [
+                StructField(
+                    "array_field",
+                    StructType(
+                        [
+                            StructField(
+                                ARRAY_ELEMENTS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                            ),
+                            StructField(
+                                ARRAY_DIMENSIONS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                            ),
+                        ]
+                    ),
+                    True,
+                )
+            ]
+        )
+
+        data = [
+            Row(array_field=Row(
+                elements=[1, 2, 3, 4, 5, 6],
+                dimensions=[2, 3]
+            )),
+            Row(array_field=Row(
+                elements=[7, 8, 9],
+                dimensions=[3, 1]
+            ))
+        ]
+
+        df = spark_session.createDataFrame(data, schema)
+        err_msg = None
+        try:
+            ArrayUtils.reshape(df).collect()
+        except Exception as e:
+            err_msg = e
+        self.assertTrue("cannot reshape array" in str(err_msg))
\ No newline at end of file
diff --git a/python/extraction-api-python3/build.gradle b/python/extraction-api-python3/build.gradle
index e1c1ed6a1f652b049ac586ae218def54279634ca..d2ab9abf0692f2a596152e9eb9fe988f20c4ebb0 100644
--- a/python/extraction-api-python3/build.gradle
+++ b/python/extraction-api-python3/build.gradle
@@ -1,5 +1,6 @@
 ext {
-    coreDependencies = ["numpy~=$numpyVersion"]
+    coreDependencies = ["numpy~=$numpyVersion",
+                        "pandas~=$pandasVersion"]
     testDependencies = [
         "coverage~=$coverageVersion",
         "pytest~=$pytestVersion",
diff --git a/python/extraction-api-python3/nxcals/api/utils/__init__.py b/python/extraction-api-python3/nxcals/api/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/nxcals/api/utils/constants.py b/python/extraction-api-python3/nxcals/api/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b4a241e33be009a6c0b8e018b6d24a3a9ab0659
--- /dev/null
+++ b/python/extraction-api-python3/nxcals/api/utils/constants.py
@@ -0,0 +1,2 @@
+ARRAY_ELEMENTS_FIELD_NAME = "elements"
+ARRAY_DIMENSIONS_FIELD_NAME = "dimensions"
diff --git a/python/extraction-api-python3/nxcals/api/utils/extraction/__init__.py b/python/extraction-api-python3/nxcals/api/utils/extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/nxcals/api/utils/extraction/array_utils.py b/python/extraction-api-python3/nxcals/api/utils/extraction/array_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..957e442b567cc6e4dd35664ea4c2d3585ca5ab42
--- /dev/null
+++ b/python/extraction-api-python3/nxcals/api/utils/extraction/array_utils.py
@@ -0,0 +1,130 @@
+from typing import List
+
+import numpy as np
+import pandas as pd
+from 
pyspark.sql.functions import col, pandas_udf
+from pyspark.sql.types import ArrayType, StructType, DataType
+from pyspark.sql import DataFrame
+
+from nxcals.api.utils.constants import (
+    ARRAY_ELEMENTS_FIELD_NAME,
+    ARRAY_DIMENSIONS_FIELD_NAME,
+)
+
+
+class ArrayUtils:
+    @staticmethod
+    def _extract_array_fields(schema: StructType) -> list:
+        """
+        Extracts field names that contain structured array data (with 'elements' and 'dimensions').
+
+        :param schema: Spark StructType schema
+        :return: List of field names that represent arrays
+        """
+        return [
+            field.name
+            for field in schema.fields
+            if isinstance(field.dataType, StructType)
+            and {ARRAY_ELEMENTS_FIELD_NAME, ARRAY_DIMENSIONS_FIELD_NAME}.issubset(
+                set(field.dataType.names)
+            )
+        ]
+
+    @staticmethod
+    def _validate_array_fields(schema, value_fields: List[str]):
+        """
+        Validates if the given columns exist in the schema and have the required structure.
+
+        :param schema: Spark StructType schema
+        :param value_fields: Column names to validate
+        :raises ValueError: If column is missing or does not conform to expected structure
+        """
+        for value_field in value_fields:
+            try:
+                field = schema[value_field]
+            except KeyError:
+                raise ValueError(f"Field '{value_field}' does not exist in the schema")
+
+            if not isinstance(field.dataType, StructType):
+                raise ValueError(f"Field '{value_field}' is not a struct type")
+
+            field_names = set(field.dataType.names)  # issubset: keep consistent with _extract_array_fields
+            if not {ARRAY_ELEMENTS_FIELD_NAME, ARRAY_DIMENSIONS_FIELD_NAME}.issubset(field_names):
+                raise ValueError(
+                    f"Field '{value_field}' must contain both 'elements' and 'dimensions' fields"
+                )
+
+    @staticmethod
+    def _reshape_array(elements: pd.Series, dimensions: pd.Series) -> pd.Series:
+        """
+        Reshapes a flat list into a multidimensional NumPy array for Pandas UDF.
+
+        :param elements: Pandas Series of lists (flattened arrays).
+        :param dimensions: Pandas Series of lists (dimensions for reshaping).
+        :return: Pandas Series of reshaped lists.
+        """
+        reshaped = []
+        for e, d in zip(elements, dimensions):
+            if len(d) == 1:
+                reshaped.append(e)
+            else:
+                reshaped.append(np.array(e).reshape(d).tolist())
+
+        return pd.Series(reshaped)
+
+    @staticmethod
+    def _get_nested_array_type(depth: int, base_type) -> DataType:
+        """
+        Recursively generates a nested ArrayType based on the depth.
+
+        :param depth: Number of dimensions.
+        :param base_type: The base element type (e.g. IntegerType()).
+        :return: Nested ArrayType.
+        """
+        array_type = base_type
+        for _ in range(depth):
+            array_type = ArrayType(array_type)
+        return array_type
+
+    @staticmethod
+    def reshape(df: DataFrame, array_columns: list = None) -> DataFrame:
+        """
+        Extracts and reshapes array columns for easier processing.
+
+        :param df: Input Spark DataFrame
+        :param array_columns: List of column names containing structured array data
+        :return: Transformed DataFrame with extracted 'elements' and 'dimensions' columns
+        """
+        schema = df.schema
+
+        if not array_columns:
+            array_columns = ArrayUtils._extract_array_fields(schema)
+        else:
+            ArrayUtils._validate_array_fields(schema, array_columns)
+
+        first_row = df.first()  # None when the DataFrame is empty
+        for column in array_columns:
+            base_array_type = (
+                schema[column].dataType[ARRAY_ELEMENTS_FIELD_NAME].dataType
+            )
+
+            first_row_dimension = first_row[column][ARRAY_DIMENSIONS_FIELD_NAME] if first_row is not None and first_row[column] is not None else None
+            if first_row_dimension:
+                dimension_length = len(first_row_dimension)
+                nested_array_type = ArrayUtils._get_nested_array_type(
+                    dimension_length - 1, base_array_type
+                )
+
+                reshape_array_udf = pandas_udf(
+                    ArrayUtils._reshape_array, nested_array_type
+                )
+
+                df = df.withColumn(
+                    column,
+                    reshape_array_udf(
+                        col(f"{column}.{ARRAY_ELEMENTS_FIELD_NAME}"),
+                        col(f"{column}.{ARRAY_DIMENSIONS_FIELD_NAME}"),
+                    ),
+                )
+
+        return df
diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/__init__.py b/python/extraction-api-python3/tests/nxcals/api/utils/__init__.py
new file mode 100644
index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/extraction/__init__.py b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..79608c846e87ec464cce13eb8bd690db226ea918
--- /dev/null
+++ b/python/extraction-api-python3/tests/nxcals/api/utils/extraction/array_utils_test.py
@@ -0,0 +1,112 @@
+import pandas as pd
+
+from unittest import TestCase
+
+from pyspark.sql.types import IntegerType, ArrayType, StructType, StructField
+
+from nxcals.api.utils.constants import (
+    ARRAY_ELEMENTS_FIELD_NAME,
+    ARRAY_DIMENSIONS_FIELD_NAME,
+)
+from nxcals.api.utils.extraction.array_utils import ArrayUtils
+
+
+class should_build_query(TestCase):
+    def test_extract_array_fields(self):
+        schema = StructType(
+            [
+                StructField("normal_col", IntegerType(), True),
+                StructField(
+                    "array_field",
+                    StructType(
+                        [
+                            StructField(
+                                ARRAY_ELEMENTS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                            StructField(
+                                ARRAY_DIMENSIONS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                        ]
+                    ),
+                    True,
+                ),
+            ]
+        )
+
+        extracted_fields = ArrayUtils._extract_array_fields(schema)
+        self.assertEqual(extracted_fields, ["array_field"])
+
+    def test_get_nested_array_type(self):
+        nested_type = ArrayUtils._get_nested_array_type(3, IntegerType())
+        self.assertIsInstance(nested_type, ArrayType)
+        self.assertIsInstance(nested_type.elementType, ArrayType)
+        self.assertIsInstance(nested_type.elementType.elementType, ArrayType)
+        self.assertIsInstance(
+            nested_type.elementType.elementType.elementType, IntegerType
+        )
+
+    def test_validate_array_fields_valid(self):
+        schema = StructType(
+            [
+                StructField(
+                    "valid_array",
+                    StructType(
+                        [
+                            StructField(
+                                ARRAY_ELEMENTS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                            StructField(
+                                ARRAY_DIMENSIONS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            ),
+                        ]
+                    ),
+                    True,
+                )
+            ]
+        )
+
+        ArrayUtils._validate_array_fields(schema, ["valid_array"])
+
+    def test_validate_array_fields_invalid(self):
+        schema = StructType(
+            [
+                StructField(
+                    "invalid_array",
+                    StructType(
+                        [
+                            StructField(
+                                ARRAY_ELEMENTS_FIELD_NAME,
+                                ArrayType(IntegerType()),
+                                True,
+                            )
+                        ]
+                    ),
+                    True,
+                )
+            ]
+        )
+
+        with self.assertRaisesRegex(
+            ValueError, "must contain both 'elements' and 'dimensions'"
+        ):
+            ArrayUtils._validate_array_fields(schema, ["invalid_array"])
+
+    def test_reshape_array_valid(self):
+        reshaped = ArrayUtils._reshape_array(
+            pd.Series([[1, 2, 3, 4]]), pd.Series([[2, 2]])
+        )
+        self.assertEqual(reshaped.tolist(), [[[1, 2], [3, 4]]])
+
+    def test_reshape_array_1D(self):
+        reshaped = ArrayUtils._reshape_array(
+            pd.Series([[1, 2, 3, 4]]), pd.Series([[4]])
+        )
+        self.assertEqual(reshaped.tolist(), [[1, 2, 3, 4]])