diff --git a/python/tfdeploy.py b/python/tfdeploy.py
index 0cb2ed787441d14b5965074adfb5a9aa62a308cf..79ec925ed53fe50fbd03f4caec7d914f7fe08bdb 100644
--- a/python/tfdeploy.py
+++ b/python/tfdeploy.py
@@ -8,17 +8,17 @@ numpy.
 
 __author__     = "Marcel Rieger"
 __copyright__  = "Copyright 2016, Marcel Rieger"
-__credits__    = ["Marcel Rieger", "Benjamin Fischer"]
+__credits__    = ["Marcel Rieger"]
 __contact__    = "https://github.com/riga/tfdeploy"
 __license__    = "MIT"
 __status__     = "Development"
-__version__    = "0.2.4"
+__version__    = "0.3.0"
 
 __all__ = ["Model", "Tensor", "Operation", "Ensemble",
-           "reset", "optimize",
            "UnknownOperationException", "OperationMismatchException",
            "InvalidImplementationException", "UnknownImplementationException",
            "EnsembleMismatchException", "ScipyOperationException",
+           "reset", "optimize", "print_tensor", "print_op", "print_tf_tensor", "print_tf_op",
            "IMPL_NUMPY", "IMPL_SCIPY", "IMPLS",
            "METHOD_MEAN", "METHOD_MAX", "METHOD_MIN", "METHOD_CUSTOM", "METHODS",
            "HAS_SCIPY"]
@@ -61,9 +61,12 @@ class Model(object):
     """
     A trained model that contains one or more converted tensorflow graphs. When *path* is set, a
     previously saved model is loaded from that path. Usage:
+
     .. code-block:: python
+
        import tensorflow as tf
        import tfdeploy as td
+
        # build your graph, use names for input and output tensors
        sess = tf.Session()
        x = tf.placeholder("float", shape=[None, 784], name="input")
@@ -71,20 +74,29 @@ class Model(object):
        b = tf.Variable(tf.zeros([100]))
        y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
        sess.run(tf.initialize_all_variables())
+
        # ... training ...
+
        # create a model and save it to disk
        model = td.Model()
        model.add(y, sess)
        model.save("/path/to/model.pkl")
+
     And then in another file:
+
     .. code-block:: python
+
        import tfdeploy as td
        import numpy as np
+
        model = td.Model("/path/to/model.pkl")
        inp, outp = model.get("input", "output")
+
        batch = np.random.rand(10000, 784)
        result = outp.eval({inp: batch})
+
     .. py:attribute:: roots
+
        The contained root tensors in a dict, mapped to their keys.
     """
 
@@ -190,14 +202,22 @@ class Tensor(object):
     of a graph. It contains information on the op it results from. The conversion uses the
     (tensorflow) instances *tf_tensor* and *tf_sess*, *tf_feed_dict* can be set to evaluate the
     tensor's current value.
+
     .. py:attribute:: name
+
        The name of the tensor.
+
     .. py:attribute:: value_index
+
        The integer value index of this tensor, i.e., the position in the op's output list.
+
     .. py:attribute:: op
+
        The op instance that defines the value of this tensor. When created from a
        ``tensorflow.Placeholder`` or a ``tensorflow.Variable``, op will be *None*.
+
     .. py:attribute:: value
+
        The value of this tensor. When created from a ``tensorflow.Variable``, this will be the value
        of that variable, or *None* otherwise until it is evaluated the first time.
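+
+    A minimal sketch, reusing the model from the :py:class:`Model` example:
+
+    .. code-block:: python
+
+       inp, outp = model.get("input", "output")
+
+       # for the placeholder, op is None; value is None until first evaluation
+       print(inp.op, inp.value)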
     """
@@ -317,26 +337,41 @@ class Operation(object):
     constructor for this op's input tensors. Op instances can have multiple implementations, i.e.,
     different methods that lead to equivalent results but might use additional third-party software
     such as *scipy*. To select a specific implementation, invoke :py:func:`use_impl`:
+
     .. code-block:: python
+
        # tell SomeOp to use the scipy implementation of its op logic
        SomeOp.use_impl(IMPL_SCIPY)
+
     See :py:func:`add_impl` for more info about adding new implementations.
+
     .. py:attribute:: types
        classmember
+
        A tuple containing the types of tensorflow ops that this op can represent.
+
     .. py:attribute:: unpack
        classmember
+
        If *True* (default), the values of evaluated input tensors are forwarded to *func* as
        single arguments; otherwise, they are passed as one list.
+
     .. py:attribute:: attrs
        classmember
+
        Names of the configuration attributes of the original tensorflow op in a tuple.
+
     .. py:attribute:: name
+
        The name of the op.
+
     .. py:attribute:: inputs
+
        Tuple of tensors that are input to this op. Their order is important as they are forwarded to
        *func* for evaluation.
+
     .. py:attribute:: kwargs
+
        Keyword arguments containing configuration values that will be passed to *func*.
     """
 
@@ -495,12 +530,15 @@ class Operation(object):
     def add_impl(cls, impl):
         """
         Decorator to add an additional implementation to this op. Example:
+
         .. code-block:: python
+
            # initial implementation using factory, defaults to numpy
            @Operation.factory
            def MyOp(a, b):
                # use numpy only
                return ...
+
            # also add a scipy implementation
            @MyOp.add_impl(IMPL_SCIPY)
            def MyOp(a, b):
@@ -529,20 +567,29 @@ class Ensemble(object):
     An ensemble is a wrapper around multiple models to compute ensemble values. It can be initialized
     with a list of model paths and an ensembling method that decides how to compute the merged
     value.
+
     .. code-block:: python
+
        # create the ensemble
        ensemble = Ensemble(["model1.pkl", "model2.pkl", ...], METHOD_MEAN)
+
        # get input and output tensors (which actually are TensorEnsemble instances)
        input, output = ensemble.get("input", "output")
+
        # evaluate the ensemble just like a normal model
        batch = ...
        value = output.eval({input: batch})
+
     If you want to use a method other than ``METHOD_MEAN``, ``METHOD_MAX`` or ``METHOD_MIN``, use
     ``METHOD_CUSTOM`` and overwrite the ``func_custom`` method of the :py:class:`TensorEnsemble`
     instance, e.g.:
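+
+    .. code-block:: python
+
+       # a sketch: use the median as a custom ensembling method
+       ensemble = Ensemble(["model1.pkl", "model2.pkl"], METHOD_CUSTOM)
+       input, output = ensemble.get("input", "output")
+       output.func_custom = lambda values: np.median(values, axis=0)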
+
     .. py:attribute:: models
+
        A list that contains all read models.
+
     .. py:attribute:: method
+
        The ensembling method.
     """
 
@@ -596,9 +643,13 @@ class TensorEnsemble(object):
     """
     A tensor ensemble basically contains a list of tensors that correspond to models of an
     :py:class:`Ensemble` instance.
+
     .. py:attribute:: tensors
+
        The list of contained tensors. Tensor *i* corresponds to model *i*.
+
     .. py:attribute:: method
+
        The ensembling method.
     """
 
@@ -677,35 +728,6 @@ class TensorEnsemble(object):
         raise NotImplementedError
 
 
-def reset():
-    """
-    Resets the instance caches of :py:class:`TensorRegister` and :py:class:`OperationRegister`.
-    """
-    TensorRegister.instances.clear()
-    OperationRegister.instances.clear()
-
-
-def optimize(order):
-    """ optimize(impl)
-    Tries to set the implementation type of all registered :py:class:`Operation` classes to *impl*.
-    This has no effect when an op does not implement that type.
-    The behavior is equivalent to:
-    .. code-block:: python
-       for op in Operation.__subclasses__():
-           if impl in op.impls:
-               op.use_impl(impl)
-    *impl* can also be a list or tuple of valid implementation types representing a preferred order.
-    """
-    if not isinstance(order, (list, tuple)):
-        order = [order]
-
-    for op in Operation.__subclasses__():
-        for impl in order:
-            if impl in op.impls:
-                op.use_impl(impl)
-                break
-
-
 class UnknownOperationException(Exception):
     """
     An exception which is raised when trying to convert an unknown tensorflow operation.
@@ -760,6 +782,142 @@ class ScipyOperationException(Exception):
         super(ScipyOperationException, self).__init__(msg)
 
 
+# parses the tf version and returns a tuple, e.g. "0.12.0-rc1" => (0, 12, 0, "rc1")
+def _parse_tf_version(v):
+    parts = v.split(".", 2)
+    if "-" in parts[2]:
+        parts.extend(parts.pop().split("-", 1))
+    return tuple([int(p) for p in parts[:3]] + parts[3:])
+
+
+# default (latest) tf version
+_tf_version_string = "0.12.0-rc1"
+_tf_version = _parse_tf_version(_tf_version_string)
+
+
+def setup(tf, order=None):
+    """
+    Sets up global variables (currently only the tensorflow version) to adapt to peculiarities of
+    different tensorflow versions. This function only needs to be called before :py:class:`Model`
+    creation, not for evaluating an already converted model. The tensorflow module *tf* must be
+    passed:
+
+    .. code-block:: python
+
+       import tensorflow as tf
+       import tfdeploy as td
+
+       td.setup(tf)
+
+       # ...
+
+    Also, when *order* is not *None*, it is forwarded to :py:func:`optimize` for convenience.
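+
+    .. code-block:: python
+
+       # a sketch: set the version and prefer scipy implementations in one call
+       td.setup(tf, td.IMPL_SCIPY)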
+    """
+    global _tf_version_string, _tf_version
+    _tf_version_string = tf.__version__
+    _tf_version = _parse_tf_version(_tf_version_string)
+
+    if order is not None:
+        optimize(order)
+
+
+def reset():
+    """
+    Resets the instance caches of :py:class:`TensorRegister` and :py:class:`OperationRegister`.
+    """
+    TensorRegister.instances.clear()
+    OperationRegister.instances.clear()
+
+
+def optimize(order):
+    """ optimize(impl)
+    Tries to set the implementation type of all registered :py:class:`Operation` classes to *impl*.
+    This has no effect when an op does not implement that type.
+
+    The behavior is equivalent to:
+
+    .. code-block:: python
+
+       for op in Operation.__subclasses__():
+           if impl in op.impls:
+               op.use_impl(impl)
+
+    *impl* can also be a list or tuple of valid implementation types representing a preferred order.
+    """
+    if not isinstance(order, (list, tuple)):
+        order = [order]
+
+    for op in Operation.__subclasses__():
+        for impl in order:
+            if impl in op.impls:
+                op.use_impl(impl)
+                break
+
+
+def print_tensor(td_tensor, indent="|   ", max_depth=-1, depth=0):
+    """ print_tensor(td_tensor, indent="    ", max_depth=-1)
+    Prints the dependency graph of a :py:class:`Tensor` *td_tensor*, where each new level is
+    indented by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where
+    each tensor and each op count as a level.
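+
+    A sketch, assuming a model saved as in the :py:class:`Model` example:
+
+    .. code-block:: python
+
+       model = td.Model("/path/to/model.pkl")
+       inp, outp = model.get("input", "output")
+
+       # print the dependency graph of the output tensor, four levels deep
+       print_tensor(outp, max_depth=4)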
+    """
+    offset = depth * indent
+    line = "td tensor: %s" % td_tensor.name
+    if td_tensor.value is not None:
+        line += " (%s)" % (",".join(str(i) for i in td_tensor.value.shape),)
+
+    print(offset + line)
+
+    if td_tensor.op and (max_depth < 0 or max_depth > depth):
+        print_op(td_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1)
+
+
+def print_op(td_op, indent="|   ", max_depth=-1, depth=0):
+    """ print_op(td_op, indent="    ", max_depth=-1)
+    Prints the dependency graph of a :py:class:`Operation` *td_op*, where each new level is indented
+    by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each
+    tensor and each op count as a level.
+    """
+    offset = depth * indent
+    line = "td op: %s (%s)" % (td_op.name, ",".join(td_op.types))
+
+    print(offset + line)
+
+    if max_depth < 0 or max_depth > depth:
+        for td_tensor in td_op.inputs:
+            print_tensor(td_tensor, indent=indent, max_depth=max_depth, depth=depth+1)
+
+
+def print_tf_tensor(tf_tensor, indent="|   ", max_depth=-1, depth=0):
+    """ print_tf_tensor(tf_tensor, indent="    ", max_depth=-1)
+    Prints the dependency graph of a tensorflow tensor *tf_tensor*, where each new level is indented
+    by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each
+    tensor and each op count as a level.
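+
+    A sketch, using the tensorflow graph built in the :py:class:`Model` example:
+
+    .. code-block:: python
+
+       y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
+       print_tf_tensor(y, max_depth=2)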
+    """
+    offset = depth * indent
+    shape = tuple(int(i) for i in tf_tensor.get_shape())
+    line = "tf tensor: %s (%s)" % (tf_tensor.name, ",".join(str(i) for i in shape))
+
+    print(offset + line)
+
+    if tf_tensor.op and (max_depth < 0 or max_depth > depth):
+        print_tf_op(tf_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1)
+
+
+def print_tf_op(tf_op, indent="|   ", max_depth=-1, depth=0):
+    """ print_tf_op(tf_tensor, indent="    ", max_depth=-1)
+    Prints the dependency graph of a tensorflow operation *tf_op*, where each new level is indented
+    by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each
+    tensor and each op count as a level.
+    """
+    offset = depth * indent
+    line = "tf op: %s (%s)" % (tf_op.name, tf_op.type)
+
+    print(offset + line)
+
+    if max_depth < 0 or max_depth > depth:
+        for tf_tensor in tf_op.inputs:
+            print_tf_tensor(tf_tensor, indent=indent, max_depth=max_depth, depth=depth+1)
+
+
 # imports exclusively for ops
 from operator import mul
 from itertools import product
@@ -767,6 +925,9 @@ from collections import defaultdict
 
 # optional import of scipy
 try:
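+    # the TD_REFUSE_SCIPY env variable forces the numpy-only implementations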
+    if os.environ.get("TD_REFUSE_SCIPY", "").lower() in ("1", "true", "yes"):
+        raise ImportError
+
     import scipy as sp
     import scipy.special
     HAS_SCIPY = True
@@ -815,15 +976,105 @@ lgamma_vec = np.vectorize(np.math.lgamma)
 erf_vec = np.vectorize(np.math.erf)
 erfc_vec = np.vectorize(np.math.erfc)
 
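+# helpers for batched matrix ops: _transpose moves the *dim*-th axis from the end
+# to the last position (the default dim=2 swaps the last two axes), _adjoint
+# additionally applies complex conjugation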
+def _transpose(a, dim=2):
+    if dim <= 0:
+        axes = None
+    else:
+        axes = list(range(a.ndim))
+        axes.append(axes.pop(-1 * dim))
+    return np.transpose(a, axes=axes)
+
+def _adjoint(a, dim=2):
+    return np.conj(_transpose(a, dim=dim))
+
+
+#
+# sequences
+#
 
 @Operation.factory
-def Identity(a):
+def LinSpace(start, stop, num):
     """
-    Identity op.
+    Linspace op.
     """
-    return np.copy(a),
+    return np.linspace(start, stop, num=num, dtype=np.float32),
+
+
+@Operation.factory
+def Range(start, limit, delta):
+    """
+    Range op.
+    """
+    return np.arange(start, limit, delta, dtype=np.int32),
+
+
+#
+# random tensors
+#
+
+@Operation.factory(attrs=("dtype", "seed"))
+def RandomStandardNormal(shape, dtype, seed):
+    """
+    Standard (mu=0, sigma=1) gaussian op.
+    """
+    if seed:
+        np.random.seed(seed)
+    return np.random.normal(size=reduce(mul, shape)).reshape(shape).astype(dtype_map[dtype]),
+
+
+@Operation.factory(attrs=("dtype", "seed"))
+def TruncatedNormal(shape, dtype, seed):
+    """
+    Standard (mu=0, sigma=1) gaussian op with truncation above 2 sigma.
+    """
+    if seed:
+        np.random.seed(seed)
+    n = reduce(mul, shape)
+    r = np.empty(n, dtype=dtype_map[dtype])
+    idxs = np.ones(n, dtype=np.bool)
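+    # rejection sampling: redraw every entry that fell outside two sigma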
+    while n:
+        r[idxs] = np.random.normal(size=n)
+        idxs = np.abs(r) > 2
+        n = np.sum(idxs)
+    return r.reshape(shape),
+
+
+@Operation.factory(attrs=("dtype", "seed"))
+def RandomUniform(shape, dtype, seed):
+    """
+    Random uniform op.
+    """
+    if seed:
+        np.random.seed(seed)
+    return np.random.uniform(size=shape).astype(dtype_map[dtype]),
+
+
+@Operation.factory(attrs=("seed",))
+def RandomUniformInt(shape, minval, maxval, seed):
+    """
+    Random uniform int op.
+    """
+    if seed:
+        np.random.seed(seed)
+    return np.random.randint(minval, maxval, size=shape),
+
+
+@Operation.factory(attrs=("seed",))
+def RandomShuffle(a, seed):
+    """
+    Random shuffle op.
+    """
+    if seed:
+        np.random.seed(seed)
+    r = a.copy()
+    np.random.shuffle(r)
+    return r,
 
 
+#
+# casting
+#
+
 @Operation.factory(types=("Cast", "StringToNumber"), output_dtypes=True)
 def Cast(a, output_dtypes):
     """
@@ -832,6 +1083,10 @@ def Cast(a, output_dtypes):
     return np.copy(a).astype(output_dtypes[0]),
 
 
+#
+# shapes and shaping
+#
+
 @Operation.factory
 def Shape(a):
     """
@@ -889,6 +1144,10 @@ def ExpandDims(a, dim):
     return np.copy(a).reshape(*shape),
 
 
+#
+# slicing and joining
+#
+
 @Operation.factory
 def Slice(a, begin, size):
     """
@@ -985,6 +1244,10 @@ def Transpose(a, perm=None):
     return np.transpose(a, axes=perm),
 
 
+#
+# arithmetic math ops
+#
+
 @Operation.factory(types=("Add", "BiasAdd"))
 def Add(a, b):
     """
@@ -1033,6 +1296,10 @@ def Cross(a, b):
     return np.cross(a, b),
 
 
+#
+# basic math ops
+#
+
 @Operation.factory(unpack=False)
 def AddN(inputs):
     """
@@ -1177,6 +1444,38 @@ def Sin(a):
     return np.sin(a),
 
 
+@Operation.factory
+def Tan(a):
+    """
+    Tan op.
+    """
+    return np.tan(a),
+
+
+@Operation.factory
+def Acos(a):
+    """
+    Acos op.
+    """
+    return np.arccos(a),
+
+
+@Operation.factory
+def Asin(a):
+    """
+    Asin op.
+    """
+    return np.arcsin(a),
+
+
+@Operation.factory
+def Atan(a):
+    """
+    Atan op.
+    """
+    return np.arctan(a),
+
+
 @Operation.factory
 def Lgamma(a):
     """
@@ -1189,6 +1488,14 @@ def Lgamma(a):
     return sp.special.gammaln(a),
 
 
+@Operation.factory(impl=IMPL_SCIPY)
+def Digamma(a):
+    """
+    Digamma op.
+    """
+    return sp.special.digamma(a),
+
+
 @Operation.factory
 def Erf(a):
     """
@@ -1213,6 +1520,58 @@ def Erfc(a):
     return sp.special.erfc(a),
 
 
+@Operation.factory
+def SquaredDifference(a, b):
+    """
+    Squared difference op, i.e. (a - b)**2.
+    """
+    return (a - b)**2,
+
+
+@Operation.factory(impl=IMPL_SCIPY)
+def Igamma(a, b):
+    """
+    Incomplete gamma op.
+    """
+    return sp.special.gammainc(a, b),
+
+
+@Operation.factory(impl=IMPL_SCIPY)
+def Igammac(a, b):
+    """
+    Complemented, incomplete gamma op.
+    """
+    return sp.special.gammaincc(a, b),
+
+
+@Operation.factory(impl=IMPL_SCIPY)
+def Zeta(a, b):
+    """
+    Zeta op.
+    """
+    return sp.special.zeta(a, b),
+
+
+@Operation.factory(impl=IMPL_SCIPY)
+def Polygamma(a, b):
+    """
+    Polygamma op.
+    """
+    return sp.special.polygamma(a, b),
+
+
+@Operation.factory(impl=IMPL_SCIPY)
+def Betainc(a, b, x):
+    """
+    Regularized incomplete beta op.
+    """
+    return sp.special.betainc(a, b, x),
+
+
+#
+# matrix math ops
+#
+
 @Operation.factory
 def Diag(a):
     """
@@ -1224,6 +1583,26 @@ def Diag(a):
     return r,
 
 
+@Operation.factory
+def DiagPart(a):
+    """
+    Diag op that returns only the diagonal elements.
+    """
+    return np.diagonal(a),
+
+
+@Operation.factory
+def MatrixDiagPart(a):
+    """
+    Batched diag op that returns only the diagonal elements.
+    """
+    r = np.zeros(a.shape[:-2] + (min(a.shape[-2:]),))
+    for coord in np.ndindex(a.shape[:-2]):
+        pos = coord + (Ellipsis,)
+        r[pos] = np.diagonal(a[pos])
+    return r,
+
+
 @Operation.factory(attrs=("transpose_a", "transpose_b"))
 def MatMul(a, b, transpose_a, transpose_b):
     """
@@ -1239,12 +1618,10 @@ def BatchMatMul(a, b, adj_a, adj_b):
     Batched matrix multiplication op.
     """
     # apply adjoint op if required along last two axes
-    axes = list(range(len(a.shape)))
-    axes.append(axes.pop(-2))
     if adj_a:
-        a = np.conj(np.transpose(a, axes=axes))
+        a = _adjoint(a)
     if adj_b:
-        b = np.conj(np.transpose(b, axes=axes))
+        b = _adjoint(b)
     # create the target tensor
     r = np.empty(a.shape[:-2] + (a.shape[-2], b.shape[-1]))
     # no batched dot op in np, so loop over all indexes except last two dims
@@ -1253,7 +1630,7 @@ def BatchMatMul(a, b, adj_a, adj_b):
     return r,
 
 
-@Operation.factory(types=("MatrixDeterminant", "BatchMatrixDeterminant"))
+@Operation.factory
 def MatrixDeterminant(a):
     """
     Matrix det op.
@@ -1261,15 +1638,15 @@ def MatrixDeterminant(a):
     return np.linalg.det(a),
 
 
-@Operation.factory(types=("MatrixInverse", "BatchMatrixInverse"))
-def MatrixInverse(a):
+@Operation.factory(attrs=("adjoint",))
+def MatrixInverse(a, adj):
     """
     Matrix inversion op.
     """
-    return np.linalg.inv(a),
+    return np.linalg.inv(a if not adj else _adjoint(a)),
 
 
-@Operation.factory(types=("Cholesky", "BatchCholesky"))
+@Operation.factory
 def Cholesky(a):
     """
     Cholesky decomposition op.
@@ -1277,7 +1654,44 @@ def Cholesky(a):
     return np.linalg.cholesky(a),
 
 
-@Operation.factory(types=("SelfAdjointEig", "BatchSelfAdjointEig"))
+@Operation.factory(attrs=("adjoint",))
+def MatrixSolve(a, rhs, adj):
+    """
+    Matrix solve op.
+    """
+    return np.linalg.solve(a if not adj else _adjoint(a), rhs),
+
+
+@Operation.factory(attrs=("lower", "adjoint"), impl=IMPL_SCIPY)
+def MatrixTriangularSolve(a, rhs, lower, adj):
+    """
+    Matrix triangular solve op.
+    """
+    trans = 0 if not adj else 2
+
+    r = np.empty(rhs.shape).astype(a.dtype)
+    for coord in np.ndindex(a.shape[:-2]):
+        pos = coord + (Ellipsis,)
+        r[pos] = sp.linalg.solve_triangular(a[pos] if not adj else np.conj(a[pos]), rhs[pos],
+                                            trans=trans, lower=lower)
+
+    return r,
+
+
+@Operation.factory
+def MatrixSolveLs(a, rhs, l2_reg):
+    """
+    Matrix least-squares solve op.
+    """
+    r = np.empty(rhs.shape).astype(a.dtype)
+    for coord in np.ndindex(a.shape[:-2]):
+        pos = coord + (Ellipsis,)
+        r[pos] = np.linalg.lstsq(a[pos], rhs[pos])[0]
+
+    return r,
+
+
+@Operation.factory
 def SelfAdjointEig(a):
     """
     Eigen decomp op.
@@ -1287,22 +1701,27 @@ def SelfAdjointEig(a):
     return np.append(*np.linalg.eig(a)).reshape(*shape),
 
 
-@Operation.factory(types=("MatrixSolve", "BatchMatrixSolve"))
-def MatrixSolve(a, b):
+@Operation.factory
+def SelfAdjointEigV2(a):
     """
-    Matrix solve op.
+    Eigen decomp op.
     """
-    return np.linalg.solve(a, b),
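+    # np.linalg.eig returns (eigenvalues, eigenvectors), matching the two outputs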
+    return np.linalg.eig(a)
 
 
-@Operation.factory
-def MatrixSolveLs(a, b, l2_regularizer):
+@Operation.factory(attrs=("compute_uv", "full_matrices"))
+def Svd(a, uv, full):
     """
-    Matrix least-squares solve op.
+    Singular value decomposition op.
     """
-    return np.linalg.lstsq(a, b)[0],
+    # np.linalg.svd returns only the singular values when compute_uv is False
+    if not uv:
+        return np.linalg.svd(a, full_matrices=full, compute_uv=False),
+    u, s, v = np.linalg.svd(a, full_matrices=full, compute_uv=True)
+    return s, u, v
 
 
+#
+# complex number ops
+#
+
 @Operation.factory
 def Complex(a, b):
     """
@@ -1343,6 +1762,10 @@ def Real(a):
     return np.real(a),
 
 
+#
+# Fourier transform ops
+#
+
 @Operation.factory
 def FFT2D(a):
     """
@@ -1359,6 +1782,26 @@ def IFFT2D(a):
     return np.fft.ifft2(a),
 
 
+@Operation.factory
+def FFT3D(a):
+    """
+    Discrete 3D FT op.
+    """
+    return np.fft.fftn(a),
+
+
+@Operation.factory
+def IFFT3D(a):
+    """
+    Discrete inverse 3D FT op.
+    """
+    return np.fft.ifftn(a),
+
+
+#
+# reduction
+#
+
 @Operation.factory(attrs=("keep_dims",))
 def Sum(a, reduction_indices, keep_dims):
     """
@@ -1415,6 +1858,10 @@ def Any(a, reduction_indices, keep_dims):
     return np.any(a, axis=tuple(reduction_indices), keepdims=keep_dims),
 
 
+#
+# segmentation
+#
+
 def seg_map(func, a, ids):
     m = defaultdict(list)
     for i, e in enumerate(ids):
@@ -1495,6 +1942,10 @@ def SparseSegmentSqrtN(a, idxs, ids):
     return seg_map(func, a, ids),
 
 
+#
+# sequence comparison and indexing
+#
+
 @Operation.factory
 def ArgMin(a, dim):
     """
@@ -1528,13 +1979,13 @@ def Where(a):
     return np.argwhere(a),
 
 
-@Operation.factory
-def Unique(a):
+@Operation.factory(attrs=("out_idx",))
+def Unique(a, t):
     """
     Unique op.
     """
     _, idxs, inv = np.unique(a, return_index=True, return_inverse=True)
-    return np.copy(a)[np.sort(idxs)], idxs[inv].astype(np.int32)
+    return np.copy(a)[np.sort(idxs)], idxs[inv].astype(dtype_map[t])
 
 
 @Operation.factory
@@ -1545,80 +1996,21 @@ def InvertPermutation(a):
     return np.argsort(a).astype(np.int32),
 
 
-@Operation.factory
-def LinSpace(start, stop, num):
-    """
-    Linspace op.
-    """
-    return np.linspace(start, stop, num=num, dtype=np.float32),
-
+#
+# control flow ops
+#
 
 @Operation.factory
-def Range(start, limit, delta):
-    """
-    Range op.
-    """
-    return np.arange(start, limit, delta, dtype=np.int32),
-
-
-@Operation.factory(attrs=("dtype", "seed"))
-def RandomStandardNormal(shape, dtype, seed):
-    """
-    Standard (mu=0, sigma=1) gaussian op.
-    """
-    if seed:
-        np.random.seed(seed)
-    return np.random.normal(size=reduce(mul, shape)).reshape(shape).astype(dtype_map[dtype]),
-
-
-@Operation.factory(attrs=("dtype", "seed"))
-def TruncatedNormal(shape, dtype, seed):
-    """
-    Standard (mu=0, sigma=1) gaussian op with truncation above 2 sigma.
-    """
-    if seed:
-        np.random.seed(seed)
-    n = reduce(mul, shape)
-    r = np.empty(n, dtype=dtype_map[dtype])
-    idxs = np.ones(n, dtype=np.bool)
-    while n:
-        r[idxs] = np.random.normal(size=n)
-        idxs = np.abs(r) > 2
-        n = np.sum(idxs)
-    return r.reshape(shape),
-
-
-@Operation.factory(attrs=("dtype", "seed"))
-def RandomUniform(shape, dtype, seed):
-    """
-    Random uniform op.
-    """
-    if seed:
-        np.random.seed(seed)
-    return np.random.uniform(size=shape).astype(dtype_map[dtype]),
-
-
-@Operation.factory(attrs=("seed",))
-def RandomUniformInt(shape, minval, maxval, seed):
+def Identity(a):
     """
-    Random uniform int op.
+    Identity op.
     """
-    if seed:
-        np.random.seed(seed)
-    return np.random.randint(minval, maxval, size=shape),
-
+    return np.copy(a),
 
-@Operation.factory(attrs=("seed",))
-def RandomShuffle(a, seed):
-    """
-    Random uniform op.
-    """
-    if seed:
-        np.random.seed(seed)
-    r = a.copy()
-    np.random.shuffle(r)
-    return r,
 
+#
+# NN activation ops
+#
 
 @Operation.factory
 def Relu(a):
@@ -1683,3 +2075,135 @@ def Softmax(a):
     """
     e = np.exp(a)
     return np.divide(e, np.sum(e, axis=-1, keepdims=True)),
+
+
+#
+# NN convolution ops
+#
+
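+# helper for the convolution and pooling ops below: for every output position it
+# gathers the (padded) input patch multiplied by the filter *f*; the callers then
+# reduce these patches over the filter axes (sum for conv, average/max for pooling)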
+def _conv_patches(a, f, strides, padding, padmode="constant"):
+    v = np.array((0,) + (a.ndim - 2) * (1,) + (0,))
+    w = np.array((0,) + f.shape[:-2] + (0,))
+
+    src = a
+    if padding == "SAME":
+        out_shape = np.ceil(np.array(a.shape).astype(np.float) / strides).astype(np.int)
+        pad = ((out_shape - v) * strides + w - a.shape).clip(min=0)
+        pad_start = pad // 2
+        if np.any(pad):
+            src = np.pad(a, list(zip(pad_start, pad - pad_start)), padmode)
+    else: # VALID
+        out_shape = np.ceil((np.array(a.shape).astype(np.float) - w + v) \
+                            / strides).astype(np.int)
+        pad = np.zeros(len(a.shape))
+
+    patches = np.empty(tuple(out_shape)[:-1] + f.shape).astype(a.dtype)
+
+    s = (slice(None),)
+    e = (Ellipsis,)
+    en = (Ellipsis, np.newaxis)
+    for coord in np.ndindex(*out_shape[1:-1]):
+        pos = np.array(strides[1:-1]) * coord
+        patches[s + coord + e] = \
+            src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-2]))][en] * f
+
+    return patches
+
+
+@Operation.factory(attrs=("strides", "padding", "data_format"))
+def Conv1D(a, f, strides, padding, data_format):
+    """
+    1D conv op.
+    """
+    if data_format.decode("ascii") == "NCHW":
+        a = np.rollaxis(a, 1, a.ndim)  # NCHW -> NHWC
+
+    patches = _conv_patches(a, f, 3 * [strides], padding.decode("ascii"))
+    conv = np.sum(patches, axis=tuple(range(-f.ndim, -1)))
+
+    if data_format.decode("ascii") == "NCHW":
+        conv = np.rollaxis(conv, -1, 1)
+
+    return conv,
+
+
+@Operation.factory(attrs=("strides", "padding", "data_format"))
+def Conv2D(a, f, strides, padding, data_format):
+    """
+    2D conv op.
+    """
+    if data_format.decode("ascii") == "NCHW":
+        a = np.rollaxis(a, 1, a.ndim)  # NCHW -> NHWC
+
+    patches = _conv_patches(a, f, strides, padding.decode("ascii"))
+    conv = np.sum(patches, axis=tuple(range(-f.ndim, -1)))
+
+    if data_format.decode("ascii") == "NCHW":
+        conv = np.rollaxis(conv, -1, 1)
+
+    return conv,
+
+
+@Operation.factory(attrs=("strides", "padding"))
+def Conv3D(a, f, strides, padding):
+    """
+    3D conv op.
+    """
+    patches = _conv_patches(a, f, strides, padding.decode("ascii"))
+    return np.sum(patches, axis=tuple(range(-f.ndim, -1))),
+
+
+#
+# NN pooling ops
+#
+
+@Operation.factory(attrs=("ksize", "strides", "padding", "data_format"))
+def AvgPool(a, k, strides, padding, data_format):
+    """
+    Average pooling op.
+    """
+    if data_format.decode("ascii") == "NCHW":
+        a = np.rollaxis(a, 1, a.ndim)  # NCHW -> NHWC
+
+    patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge")
+    pool = np.average(patches, axis=tuple(range(-len(k), -1)))
+
+    if data_format.decode("ascii") == "NCHW":
+        pool = np.rollaxis(pool, -1, 1)
+
+    return pool,
+
+
+@Operation.factory(attrs=("ksize", "strides", "padding", "data_format"))
+def MaxPool(a, k, strides, padding, data_format):
+    """
+    Maximum pooling op.
+    """
+    if data_format.decode("ascii") == "NCHW":
+        a = np.rollaxis(a, 1, a.ndim)  # NCHW -> NHWC
+
+    patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge")
+    pool = np.amax(patches, axis=tuple(range(-len(k), -1)))
+
+    if data_format.decode("ascii") == "NCHW":
+        pool = np.rollaxis(pool, -1, 1)
+
+    return pool,
+
+
+@Operation.factory(attrs=("ksize", "strides", "padding"))
+def AvgPool3D(a, k, strides, padding):
+    """
+    Average 3D pooling op.
+    """
+    patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge")
+    return np.average(patches, axis=tuple(range(-len(k), -1))),
+
+
+@Operation.factory(attrs=("ksize", "strides", "padding"))
+def MaxPool3D(a, k, strides, padding):
+    """
+    Maximum 3D pooling op.
+    """
+    patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge")
+    return np.amax(patches, axis=tuple(range(-len(k), -1))),