Commit be5cae67 authored by Rosen Matev's avatar Rosen Matev
Browse files

Merge branch 'NN_fixLoI' into 'master'

Fix list of inputs in case of datahandles and list of inputs for tools

Closes #44 and #36

See merge request lhcb/Moore!201
parents 5e214c0c 107ac5e9
......@@ -30,7 +30,7 @@ import re
from GaudiKernel.ConfigurableMeta import ConfigurableMeta
from . import ConfigurationError, configurable
from .dataflow import DataHandle, configurable_outputs, configurable_inputs, dataflow_config, is_datahandle
from .dataflow import DataHandle, configurable_outputs, configurable_inputs, dataflow_config, contains_datahandle
from .utilities import graphviz_module
__all__ = [
......@@ -39,6 +39,7 @@ __all__ = [
'force_location',
'make_algorithm',
'make_tool',
'contains_algorithm',
'is_algorithm',
'is_tool',
'setup_component',
......@@ -147,8 +148,13 @@ def _gather_tool_names(tool_dict):
def is_algorithm(arg):
"""Returns True if arg is of type Algorithm"""
return isinstance(arg, Algorithm)
def contains_algorithm(arg):
"""Return True if arg is an Algorithm instance or list of Algorithm instances."""
return isinstance(arg, Algorithm) or _is_list_of_algs(arg)
return is_algorithm(arg) or _is_list_of_algs(arg)
def _is_list_of_algs(iterable):
......@@ -156,8 +162,8 @@ def _is_list_of_algs(iterable):
Returns False if the iterable is empty.
"""
return False if not iterable else (isinstance(iterable, list)
and all(map(is_algorithm, iterable)))
return False if not iterable else (
isinstance(iterable, list) and all(map(contains_algorithm, iterable)))
def is_tool(arg):
......@@ -167,7 +173,7 @@ def is_tool(arg):
def _is_input(arg):
"""Return True if arg is something that produces output."""
return is_datahandle(arg) or is_algorithm(arg)
return contains_datahandle(arg) or contains_algorithm(arg)
def _get_input(arg):
......@@ -448,8 +454,8 @@ class Algorithm(object):
props_hash = _hash_dict(props)
# TODO include the transformed input somehow
inputs_hash = _hash_dict(
{key: _datahandle_ids(handle)
for key, handle in inputs.items()})
{key: _datahandle_ids(handles)
for key, handles in inputs.items()})
tools_hash = _hash_dict({key: tool.id for key, tool in tools.items()})
outputs_hash = _hash_list(forced_outputs)
to_be_hashed = [
......@@ -755,7 +761,7 @@ class Tool(object):
instance = super(Tool, cls).__new__(cls)
instance._id = identity
instance._parent = None
instance._tooltype = tool_type
instance._tool_type = tool_type
instance._name = _get_unique_name(name or tool_type.getType())
instance._private = False
instance._inputs = _inputs
......@@ -771,8 +777,8 @@ class Tool(object):
def _calc_id(typename, props, inputs, tools, parent=None):
props_hash = _hash_dict(props)
inputs_hash = _hash_dict(
{key: handle.id
for key, handle in inputs.items()})
{key: _datahandle_ids(handles)
for key, handles in inputs.items()})
tools_hash = _hash_dict({key: tool.id for key, tool in tools.items()})
to_be_hashed = [typename, props_hash, inputs_hash, tools_hash]
if parent:
......@@ -816,7 +822,7 @@ class Tool(object):
@property
def type(self):
return self._tooltype
return self._tool_type
@property
def typename(self):
......@@ -833,7 +839,7 @@ class Tool(object):
self._parent = parent
self._private = True
del self._tool_store[self._id]
self._id = self._calc_id(self._tooltype, self._properties,
self._id = self._calc_id(self._tool_type, self._properties,
self._inputs, self._tools, parent)
self._tool_store[self._id] = self
object.__setattr__(self, "_readonly", True) # make immutable again
......@@ -854,8 +860,10 @@ class Tool(object):
self))
config = dataflow_config()
for i in self.inputs.values():
config.update(i.producer.configuration())
for inputs in self.inputs.values():
inputs = inputs if isinstance(inputs, list) else [inputs]
for inp in inputs:
config.update(inp.producer.configuration())
for tool in self.tools.values():
config.update(tool.configuration())
......
......@@ -148,9 +148,24 @@ class DataHandle(object):
def is_datahandle(arg):
"""Returns True if arg is of type DataHandle"""
return isinstance(arg, DataHandle)
def contains_datahandle(arg):
"""Return True if arg is a DataHandle instance or list of DataHandle instances."""
return is_datahandle(arg) or _is_list_of_datahandles(arg)
def _is_list_of_datahandles(iterable):
"""Return True if all elements are DataHandle instances.
Returns False if the iterable is empty.
"""
return False if not iterable else (
isinstance(iterable, list) and all(map(contains_datahandle, iterable)))
def is_datahandle_writer(x):
"""Return True if x is a writer DataHandle or a list of them.
......
......@@ -84,16 +84,19 @@ def test_init_input_location_list():
vdp1 = Algorithm(VectorDataProducer)
vdp2 = Algorithm(VectorDataProducer)
multitransformer = Algorithm(
IntVectorsToIntVector, InputLocations=[vdp1, vdp2])
mt = Algorithm(IntVectorsToIntVector, InputLocations=[vdp1, vdp2])
mt2 = Algorithm(
IntVectorsToIntVector,
InputLocations=[vdp1.OutputLocation, vdp2.OutputLocation])
# If we initialize the MultiTransformer with inputs, then we can detect the
# DataHandles and everything looks sensible (wrt. the FIXME above)
assert len(multitransformer.inputs) == 1
assert len(multitransformer.inputs['InputLocations']) == 2
assert all(
isinstance(x, DataHandle)
for x in multitransformer.inputs['InputLocations'])
assert len(mt.inputs) == 1
assert len(mt2.inputs) == 1
assert len(mt.inputs['InputLocations']) == 2
assert len(mt2.inputs['InputLocations']) == 2
assert all(isinstance(x, DataHandle) for x in mt.inputs['InputLocations'])
assert all(isinstance(x, DataHandle) for x in mt2.inputs['InputLocations'])
def test_tool_compat():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment