Skip to content
Snippets Groups Projects
Commit b324999f authored by Marcel Rieger's avatar Marcel Rieger
Browse files

Update tfdeploy version.

parent b8a83802
No related branches found
No related tags found
No related merge requests found
...@@ -12,7 +12,7 @@ __credits__ = ["Marcel Rieger"] ...@@ -12,7 +12,7 @@ __credits__ = ["Marcel Rieger"]
__contact__ = "https://github.com/riga/tfdeploy" __contact__ = "https://github.com/riga/tfdeploy"
__license__ = "MIT" __license__ = "MIT"
__status__ = "Development" __status__ = "Development"
__version__ = "0.3.0" __version__ = "0.3.2"
__all__ = ["Model", "Tensor", "Operation", "Ensemble", __all__ = ["Model", "Tensor", "Operation", "Ensemble",
"UnknownOperationException", "OperationMismatchException", "UnknownOperationException", "OperationMismatchException",
...@@ -2081,7 +2081,7 @@ def Softmax(a): ...@@ -2081,7 +2081,7 @@ def Softmax(a):
# NN convolution ops # NN convolution ops
# #
def _conv_patches(a, f, strides, padding, padmode="constant"): def _prepare_patches(a, f, strides, padding, padmode):
v = np.array((0,) + (a.ndim - 2) * (1,) + (0,)) v = np.array((0,) + (a.ndim - 2) * (1,) + (0,))
w = np.array((0,) + f.shape[:-2] + (0,)) w = np.array((0,) + f.shape[:-2] + (0,))
...@@ -2097,6 +2097,12 @@ def _conv_patches(a, f, strides, padding, padmode="constant"): ...@@ -2097,6 +2097,12 @@ def _conv_patches(a, f, strides, padding, padmode="constant"):
/ strides).astype(np.int) / strides).astype(np.int)
pad = np.zeros(len(a.shape)) pad = np.zeros(len(a.shape))
return out_shape, src
def _conv_patches(a, f, strides, padding):
out_shape, src = _prepare_patches(a, f, strides, padding, "constant")
patches = np.empty(tuple(out_shape)[:-1] + f.shape).astype(a.dtype) patches = np.empty(tuple(out_shape)[:-1] + f.shape).astype(a.dtype)
s = (slice(None),) s = (slice(None),)
...@@ -2157,6 +2163,24 @@ def Conv3D(a, f, strides, padding): ...@@ -2157,6 +2163,24 @@ def Conv3D(a, f, strides, padding):
# NN pooling ops # NN pooling ops
# #
def _pool_patches(a, k, strides, padding):
f = np.ones(k[1:] + [a.shape[-1]])
out_shape, src = _prepare_patches(a, f, strides, padding, "edge")
patches = np.empty(tuple(out_shape) + f.shape).astype(a.dtype)
s = (slice(None),)
e = (Ellipsis,)
en = (Ellipsis, np.newaxis)
for coord in np.ndindex(*out_shape[1:]):
pos = np.array(strides[1:]) * coord
patches[s + coord + e] = \
src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-1]))][en] * f
return patches
@Operation.factory(attrs=("ksize", "strides", "padding", "data_format")) @Operation.factory(attrs=("ksize", "strides", "padding", "data_format"))
def AvgPool(a, k, strides, padding, data_format): def AvgPool(a, k, strides, padding, data_format):
""" """
...@@ -2165,8 +2189,8 @@ def AvgPool(a, k, strides, padding, data_format): ...@@ -2165,8 +2189,8 @@ def AvgPool(a, k, strides, padding, data_format):
if data_format.decode("ascii") == "NCHW": if data_format.decode("ascii") == "NCHW":
a = np.rollaxis(a, 1, -1), a = np.rollaxis(a, 1, -1),
patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge") patches = _pool_patches(a, k, strides, padding.decode("ascii"))
pool = np.average(patches, axis=tuple(range(-len(k), -1))) pool = np.average(patches, axis=tuple(range(-len(k), 0)))
if data_format.decode("ascii") == "NCHW": if data_format.decode("ascii") == "NCHW":
pool = np.rollaxis(pool, -1, 1) pool = np.rollaxis(pool, -1, 1)
...@@ -2182,8 +2206,8 @@ def MaxPool(a, k, strides, padding, data_format): ...@@ -2182,8 +2206,8 @@ def MaxPool(a, k, strides, padding, data_format):
if data_format.decode("ascii") == "NCHW": if data_format.decode("ascii") == "NCHW":
a = np.rollaxis(a, 1, -1), a = np.rollaxis(a, 1, -1),
patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge") patches = _pool_patches(a, k, strides, padding.decode("ascii"))
pool = np.amax(patches, axis=tuple(range(-len(k), -1))) pool = np.amax(patches, axis=tuple(range(-len(k), 0)))
if data_format.decode("ascii") == "NCHW": if data_format.decode("ascii") == "NCHW":
pool = np.rollaxis(pool, -1, 1) pool = np.rollaxis(pool, -1, 1)
...@@ -2196,8 +2220,8 @@ def AvgPool3D(a, k, strides, padding): ...@@ -2196,8 +2220,8 @@ def AvgPool3D(a, k, strides, padding):
""" """
Average 3D pooling op. Average 3D pooling op.
""" """
patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge") patches = _pool_patches(a, k, strides, padding.decode("ascii"))
return np.average(patches, axis=tuple(range(-len(k), -1))), return np.average(patches, axis=tuple(range(-len(k), 0))),
@Operation.factory(attrs=("ksize", "strides", "padding")) @Operation.factory(attrs=("ksize", "strides", "padding"))
...@@ -2205,5 +2229,5 @@ def MaxPool3D(a, k, strides, padding): ...@@ -2205,5 +2229,5 @@ def MaxPool3D(a, k, strides, padding):
""" """
Maximum 3D pooling op. Maximum 3D pooling op.
""" """
patches = _conv_patches(a, np.ones(k[1:] + [1]), strides, padding.decode("ascii"), "edge") patches = _pool_patches(a, k, strides, padding.decode("ascii"))
return np.amax(patches, axis=tuple(range(-len(k), -1))), return np.amax(patches, axis=tuple(range(-len(k), 0))),
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment