From 8173ebd396e22c05029375f593bbdb595b4bccbb Mon Sep 17 00:00:00 2001 From: Simon Albright <simon.albright@cern.ch> Date: Tue, 28 Nov 2023 14:38:23 +0100 Subject: [PATCH 1/2] clean up TrackIteration --- blond/utils/track_iteration.py | 104 +++++++++++++++++++++------------ 1 file changed, 66 insertions(+), 38 deletions(-) diff --git a/blond/utils/track_iteration.py b/blond/utils/track_iteration.py index 2b8c2e44..12853554 100644 --- a/blond/utils/track_iteration.py +++ b/blond/utils/track_iteration.py @@ -12,6 +12,20 @@ user specified functions every n turns** :Authors: **Simon Albright** ''' +# Futurisation +from __future__ import annotations + +# General imports +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from typing import Iterable, List, Callable, Protocol, Any, Self + + class Trackable(Protocol): + def track(self) -> None: + ... + class TrackIteration: @@ -22,37 +36,44 @@ class TrackIteration: Parameters ---------- - trackMap : iterable of objects + track_map : iterable of objects Each object will be called on every turn with object.track() - initTurn : integer + init_turn : integer The turn number tracking will start from, only used to initialise a turn counter - finalTurn : integer + final_turn : integer The last turn number to track next(TrackIteration) will raise - StopIteration when turnNumber == finalTurn + StopIteration when turn_number == final_turn Attributes ---------- - - functionList : List of functions to be called with specified interval + function_list : List of functions to be called with specified interval ''' - def __init__(self, trackMap, initTurn=0, finalTurn=-1): - if not all((callable(m) for m in trackMap)): - raise AttributeError("All map objects must be callable") + def __init__(self, track_map: Iterable[Trackable], init_turn: int = 0, + final_turn: int = -1): + + if not all((hasattr(m, 'track') for m in track_map)): + raise AttributeError("All map objects must be trackable") - self._map = trackMap - if isinstance(initTurn, int): - self.turnNumber = initTurn + self._map = track_map + if isinstance(init_turn, int): + self.turn_number = init_turn else: - raise TypeError("initTurn must be an integer") - if isinstance(finalTurn, int): - self._finalTurn = finalTurn + raise TypeError("init_turn must be an integer") + + if isinstance(final_turn, int): + self._final_turn = final_turn else: - raise TypeError("finalTurn must be an integer") + raise TypeError("final_turn must be an integer") + + self.function_list: List[ + Tuple[ + Callable[[Iterable[Trackable], int, ...]], + int] + ] = [] - self.functionList = [] def _track_turns(self, n_turns): ''' @@ -63,66 +84,73 @@ class TrackIteration: for i in range(n_turns): next(self) - def add_function(self, predicate, repetionRate, *args, **kwargs): + + def add_function(self, predicate: Callable[[Iterable[Trackable], int, ...]], + repetion_rate: int, *args: Any, **kwargs: Any): ''' - Takes a user defined callable and calls it every repetionRate - number of turns with predicate(trackMap, turnNumber, ``*args``, ``**kwargs``) + Takes a user defined callable and calls it every repetion_rate + number of turns with predicate(track_map, turn_number, *args, **kwargs) ''' - self.functionList.append((self._partial(predicate, args, kwargs), repetionRate)) + self.function_list.append((self._partial(predicate, *args, **kwargs), + repetion_rate)) + - def __next__(self): + def __next__(self) -> int: ''' - First raises StopIteration if turnNumber == finalTurn + First raises StopIteration if turn_number == final_turn Next calls track() from each element in trackMap list and raises StopIteration if no more turns available Finally iterates over each function specified in add_function - and calls them with predicate(trackMap, turnNumber) if - turnNumber % repetitionRate == 0 + and calls them with predicate(trackMap, turn_number) if + turn_number % repetitionRate == 0 ''' - if self.turnNumber == self._finalTurn: + if self.turn_number == self._final_turn: raise StopIteration try: for m in self._map: - m() + m.track() except IndexError: raise StopIteration - self.turnNumber += 1 + self.turn_number += 1 - for func, rate in self.functionList: - if self.turnNumber % rate == 0: - func(self._map, self.turnNumber) + for func, rate in self.function_list: + if self.turn_number % rate == 0: + func(self._map, self.turn_number) - return self.turnNumber + return self.turn_number - def __iter__(self): + + def __iter__(self) -> Self: ''' returns self ''' return self - def __call__(self, n_turns=1): + + def __call__(self, n_turns: int = 1) -> int: ''' Makes object callable with option to specify number of tracked turns default tracks 1 turn ''' self._track_turns(n_turns) - return self.turnNumber + return self.turn_number + - def _partial(self, predicate, args, kwargs): + def _partial(self, predicate: Callable, *args, **kwargs) -> Callable: ''' reimplementation of functools.partial to prepend rather than append to *args ''' - def partFunc(_map, turn): + def part_func(_map, turn): return predicate(_map, turn, *args, **kwargs) - return partFunc + return part_func -- GitLab From 0e6c7c742b5d87d652483023d3fcbdfc8aef7889 Mon Sep 17 00:00:00 2001 From: Simon Albright <simon.albright@cern.ch> Date: Mon, 4 Dec 2023 10:03:00 +0100 Subject: [PATCH 2/2] fixed TrackIteration unit tests --- unittests/utils/test_iteration.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/unittests/utils/test_iteration.py b/unittests/utils/test_iteration.py index d8e6a1e6..29c07d85 100644 --- a/unittests/utils/test_iteration.py +++ b/unittests/utils/test_iteration.py @@ -36,7 +36,7 @@ class TestTrackIteration(unittest.TestCase): self.n_turns = self.ring.n_turns - self.map_ = [self.full_ring.track, self.profile.track] + self.map_ = [self.full_ring, self.profile] self.trackIt = TrackIteration(self.map_) @@ -56,13 +56,13 @@ class TestTrackIteration(unittest.TestCase): for i in range(5): next(self.trackIt) - self.assertEqual(self.trackIt.turnNumber, 8, msg='Turn number should have incremented to 8') + self.assertEqual(self.trackIt.turn_number, 8, msg='Turn number should have incremented to 8') def test_iter(self): for i in self.trackIt: pass - self.assertEqual(self.n_turns, self.trackIt.turnNumber, msg='Iterating all turns has not incremented turnNumber correctly') + self.assertEqual(self.n_turns, self.trackIt.turn_number, msg='Iterating all turns has not incremented turnNumber correctly') def test_call(self): @@ -99,21 +99,26 @@ class TestTrackIteration(unittest.TestCase): self.trackIt(3) self.assertEqual(list1[0], 4, msg='function call should have incremented list') - self.assertEqual(list2[0], self.trackIt.turnNumber, msg='function should set list[0] to turn number') + self.assertEqual(list2[0], self.trackIt.turn_number, msg='function should set list[0] to turn number') self.assertEqual(list3[0], 8, msg='function should have been called') def test_exceptions(self): + class TestTrackable: + def track(self): + ... + testPasses = [None, 1, 'abc'] for t in testPasses: - with self.assertRaises(AttributeError, msg='Should raise AttrinuteError if non-callable object is passed in map'): + with self.assertRaises(AttributeError, msg='Should raise AttrinuteError if non-trackable object is passed in map'): TrackIteration([t]) testPasses = [None, 1., 'abc'] + for t in testPasses: with self.assertRaises(TypeError, msg='Should raise TypeError if initTurn is non-integer'): - TrackIteration([lambda _: _], t) + TrackIteration([TestTrackable()], t) with self.assertRaises(TypeError, msg='Should raise TypeError if initTurn is non-integer'): - TrackIteration([lambda _: _], 0, t) + TrackIteration([TestTrackable()], 0, t) def testItt(): for i in range(self.n_turns + 1): -- GitLab