From 8173ebd396e22c05029375f593bbdb595b4bccbb Mon Sep 17 00:00:00 2001
From: Simon Albright <simon.albright@cern.ch>
Date: Tue, 28 Nov 2023 14:38:23 +0100
Subject: [PATCH 1/2] clean up TrackIteration

---
 blond/utils/track_iteration.py | 104 +++++++++++++++++++++------------
 1 file changed, 66 insertions(+), 38 deletions(-)

diff --git a/blond/utils/track_iteration.py b/blond/utils/track_iteration.py
index 2b8c2e44..12853554 100644
--- a/blond/utils/track_iteration.py
+++ b/blond/utils/track_iteration.py
@@ -12,6 +12,20 @@ user specified functions every n turns**
 
 :Authors: **Simon Albright**
 '''
+# Futurisation
+from __future__ import annotations
+
+# General imports
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from typing import Iterable, List, Callable, Protocol, Any, Self
+
+    class Trackable(Protocol):
+        def track(self) -> None:
+            ...
+
 
 
 class TrackIteration:
@@ -22,37 +36,44 @@ class TrackIteration:
     Parameters
     ----------
 
-    trackMap : iterable of objects
+    track_map : iterable of objects
         Each object will be called on every turn with object.track()
-    initTurn : integer
+    init_turn : integer
         The turn number tracking will start from, only used to initialise
         a turn counter
-    finalTurn : integer
+    final_turn : integer
         The last turn number to track next(TrackIteration) will raise
-        StopIteration when turnNumber == finalTurn
+        StopIteration when turn_number == final_turn
 
     Attributes
     ----------
-
-    functionList : List of functions to be called with specified interval
+    function_list : List of functions to be called with specified interval
     '''
 
-    def __init__(self, trackMap, initTurn=0, finalTurn=-1):
 
-        if not all((callable(m) for m in trackMap)):
-            raise AttributeError("All map objects must be callable")
+    def __init__(self, track_map: Iterable[Trackable], init_turn: int = 0,
+                 final_turn: int = -1):
+
+        if not all((hasattr(m, 'track') for m in track_map)):
+            raise AttributeError("All map objects must be trackable")
 
-        self._map = trackMap
-        if isinstance(initTurn, int):
-            self.turnNumber = initTurn
+        self._map = track_map
+        if isinstance(init_turn, int):
+            self.turn_number = init_turn
         else:
-            raise TypeError("initTurn must be an integer")
-        if isinstance(finalTurn, int):
-            self._finalTurn = finalTurn
+            raise TypeError("init_turn must be an integer")
+
+        if isinstance(final_turn, int):
+            self._final_turn = final_turn
         else:
-            raise TypeError("finalTurn must be an integer")
+            raise TypeError("final_turn must be an integer")
+
+        self.function_list: List[
+                            Tuple[
+                            Callable[[Iterable[Trackable], int, ...]],
+                            int]
+                            ] = []
 
-        self.functionList = []
 
     def _track_turns(self, n_turns):
         '''
@@ -63,66 +84,73 @@ class TrackIteration:
         for i in range(n_turns):
             next(self)
 
-    def add_function(self, predicate, repetionRate, *args, **kwargs):
+
+    def add_function(self, predicate: Callable[[Iterable[Trackable], int, ...]],
+                     repetion_rate: int, *args: Any, **kwargs: Any):
         '''
-        Takes a user defined callable and calls it every repetionRate
-        number of turns with predicate(trackMap, turnNumber, ``*args``, ``**kwargs``)
+        Takes a user defined callable and calls it every repetion_rate
+        number of turns with predicate(track_map, turn_number, *args, **kwargs)
         '''
 
-        self.functionList.append((self._partial(predicate, args, kwargs), repetionRate))
+        self.function_list.append((self._partial(predicate, *args, **kwargs),
+                                  repetion_rate))
+
 
-    def __next__(self):
+    def __next__(self) -> int:
         '''
-        First raises StopIteration if turnNumber == finalTurn
+        First raises StopIteration if turn_number == final_turn
 
         Next calls track() from each element in trackMap list and raises
         StopIteration if no more turns available
 
         Finally iterates over each function specified in add_function
-        and calls them with predicate(trackMap, turnNumber) if
-        turnNumber % repetitionRate == 0
+        and calls them with predicate(trackMap, turn_number) if
+        turn_number % repetitionRate == 0
         '''
 
-        if self.turnNumber == self._finalTurn:
+        if self.turn_number == self._final_turn:
             raise StopIteration
 
         try:
             for m in self._map:
-                m()
+                m.track()
         except IndexError:
             raise StopIteration
 
-        self.turnNumber += 1
+        self.turn_number += 1
 
-        for func, rate in self.functionList:
-            if self.turnNumber % rate == 0:
-                func(self._map, self.turnNumber)
+        for func, rate in self.function_list:
+            if self.turn_number % rate == 0:
+                func(self._map, self.turn_number)
 
-        return self.turnNumber
+        return self.turn_number
 
-    def __iter__(self):
+
+    def __iter__(self) -> Self:
         '''
         returns self
         '''
 
         return self
 
-    def __call__(self, n_turns=1):
+
+    def __call__(self, n_turns: int = 1) -> int:
         '''
         Makes object callable with option to specify number of tracked turns
         default tracks 1 turn
         '''
 
         self._track_turns(n_turns)
-        return self.turnNumber
+        return self.turn_number
+
 
-    def _partial(self, predicate, args, kwargs):
+    def _partial(self, predicate: Callable, *args, **kwargs) -> Callable:
         '''
         reimplementation of functools.partial to prepend
         rather than append to *args
         '''
 
-        def partFunc(_map, turn):
+        def part_func(_map, turn):
             return predicate(_map, turn, *args, **kwargs)
 
-        return partFunc
+        return part_func
-- 
GitLab


From 0e6c7c742b5d87d652483023d3fcbdfc8aef7889 Mon Sep 17 00:00:00 2001
From: Simon Albright <simon.albright@cern.ch>
Date: Mon, 4 Dec 2023 10:03:00 +0100
Subject: [PATCH 2/2] fixed TrackIteration unit tests

---
 unittests/utils/test_iteration.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/unittests/utils/test_iteration.py b/unittests/utils/test_iteration.py
index d8e6a1e6..29c07d85 100644
--- a/unittests/utils/test_iteration.py
+++ b/unittests/utils/test_iteration.py
@@ -36,7 +36,7 @@ class TestTrackIteration(unittest.TestCase):
 
         self.n_turns = self.ring.n_turns
 
-        self.map_ = [self.full_ring.track, self.profile.track]
+        self.map_ = [self.full_ring, self.profile]
 
         self.trackIt = TrackIteration(self.map_)
 
@@ -56,13 +56,13 @@ class TestTrackIteration(unittest.TestCase):
         for i in range(5):
             next(self.trackIt)
 
-        self.assertEqual(self.trackIt.turnNumber, 8, msg='Turn number should have incremented to 8')
+        self.assertEqual(self.trackIt.turn_number, 8, msg='Turn number should have incremented to 8')
 
     def test_iter(self):
 
         for i in self.trackIt:
             pass
-        self.assertEqual(self.n_turns, self.trackIt.turnNumber, msg='Iterating all turns has not incremented turnNumber correctly')
+        self.assertEqual(self.n_turns, self.trackIt.turn_number, msg='Iterating all turns has not incremented turnNumber correctly')
 
     def test_call(self):
 
@@ -99,21 +99,26 @@ class TestTrackIteration(unittest.TestCase):
         self.trackIt(3)
 
         self.assertEqual(list1[0], 4, msg='function call should have incremented list')
-        self.assertEqual(list2[0], self.trackIt.turnNumber, msg='function should set list[0] to turn number')
+        self.assertEqual(list2[0], self.trackIt.turn_number, msg='function should set list[0] to turn number')
         self.assertEqual(list3[0], 8, msg='function should have been called')
 
     def test_exceptions(self):
 
+        class TestTrackable:
+            def track(self):
+                ...
+
         testPasses = [None, 1, 'abc']
         for t in testPasses:
-            with self.assertRaises(AttributeError, msg='Should raise AttrinuteError if non-callable object is passed in map'):
+            with self.assertRaises(AttributeError, msg='Should raise AttrinuteError if non-trackable object is passed in map'):
                 TrackIteration([t])
         testPasses = [None, 1., 'abc']
+
         for t in testPasses:
             with self.assertRaises(TypeError, msg='Should raise TypeError if initTurn is non-integer'):
-                TrackIteration([lambda _: _], t)
+                TrackIteration([TestTrackable()], t)
             with self.assertRaises(TypeError, msg='Should raise TypeError if initTurn is non-integer'):
-                TrackIteration([lambda _: _], 0, t)
+                TrackIteration([TestTrackable()], 0, t)
 
         def testItt():
             for i in range(self.n_turns + 1):
-- 
GitLab