# Source code for GPy.kern.src.multioutput_kern

# Copyright (c) 2018, GPy authors (see AUTHORS.txt).

from .kern import Kern, CombinationKernel
import numpy as np
from functools import reduce, partial
from ...util.multioutput import index_to_slices
from paramz.caching import Cache_this

class ZeroKern(Kern):
    """
    Kernel that is identically zero.

    MultioutputKern uses an instance of it for every pair of outputs that has no
    cross covariance, so every method returns zeros of the appropriate shape.
    It holds no parameters: its gradient is an empty array and assignments to
    the gradient are ignored.
    """
    def __init__(self):
        super(ZeroKern, self).__init__(1, None, name='ZeroKern', useGPU=False)

    def K(self, X, X2=None):
        """
        Return the zero covariance between X and X2.

        :param X: first input array (rows are data points).
        :param X2: optional second input array; defaults to X.
        :returns: zeros of shape (X.shape[0], X2.shape[0]).
        """
        if X2 is None:
            X2 = X
        return np.zeros((X.shape[0], X2.shape[0]))

    def update_gradients_full(self, dL_dK, X, X2=None):
        """
        No parameters to update; returns zeros shaped like dL_dK.
        """
        return np.zeros(dL_dK.shape)

    def gradients_X(self, dL_dK, X, X2=None):
        """
        The covariance is constant, so the gradient w.r.t. X is zero.

        :returns: zeros of shape (X.shape[0], X.shape[1]).
        """
        return np.zeros((X.shape[0], X.shape[1]))

    @property
    def gradient(self):
        # No parameters, so the gradient is an empty (1, 0) array.
        return np.empty((1, 0))

    @gradient.setter
    def gradient(self, value):
        # Nothing to store: assignments to the gradient are discarded.
        pass

[docs]class MultioutputKern(CombinationKernel):
"""
Multioutput kernel is a meta class for combining different kernels for multioutput GPs.

As an example let us have inputs x1 for output 1 with covariance k1 and x2 for output 2 with covariance k2.
In addition, we need to define the cross covariances k12(x1,x2) and k21(x2,x1). Then the kernel becomes:
k([x1,x2],[x1,x2]) = [k1(x1,x1), k12(x1,x2); k21(x2,x1), k2(x2,x2)]

For the kernel, the kernels of the outputs are given as a list in param "kernels", and cross covariances are
given in param "cross_covariances" as a dictionary keyed by tuples (i,j). If no cross covariance is given,
it defaults to zero, as in k12(x1,x2)=0.

In the cross covariance dictionary, each value is used directly as the (i,j) block of the covariance, so it
must behave like a member of the Kernel class:
-it stores the hyper parameters to be updated when optimizing the GP
-'K': a method defining the cross covariance
-'gradients_X': a method giving the gradient of the cross covariance with respect to the first input
"""
def __init__(self, kernels, cross_covariances=None, name='MultioutputKern'):
    """
    Build the block covariance structure of the multi-output kernel.

    :param kernels: a single kernel, or a list with one kernel per output. If the
        same kernel object appears at positions i and j of the list, it is also
        used as the (i, j) cross covariance block.
    :param cross_covariances: optional dict mapping an output pair (i, j) to the
        object used as the (i, j) covariance block. Lookups are not symmetrized,
        so (i, j) and (j, i) must be given separately. Pairs without an entry get
        a ZeroKern, i.e. zero cross covariance.
    :param name: name of the kernel.
    """
    # None default instead of {} so no mutable dict is shared between calls.
    if cross_covariances is None:
        cross_covariances = {}

    if isinstance(kernels, list):
        self.single_kern = False
        self.kern = kernels
    else:
        # Remember the bare kernel itself; the combination machinery needs a list.
        self.single_kern = True
        self.kern = kernels
        kernels = [kernels]

    # The combination kernel ALWAYS puts the extra dimension last.
    # Thus, the index dimension of this kernel is always the last dimension
    # after slicing. This is why the index_dim is just the last column:
    self.index_dim = -1

    # Parameters are linked manually below: the same kernel object may appear
    # several times in `kernels` and must only be linked once.
    super(MultioutputKern, self).__init__(kernels=kernels, extra_dims=[self.index_dim],
                                          name=name, link_parameters=False)

    nl = len(kernels)
    covariance = [[None] * nl for _ in range(nl)]
    linked = []
    for i in range(nl):
        unique = True  # becomes False if kernels[i] already appeared at some j < i
        for j in range(nl):
            if i == j or kernels[i] is kernels[j]:
                covariance[i][j] = kernels[i]
                if i > j:
                    unique = False
            else:
                cross = cross_covariances.get((i, j))
                # Missing cross covariance -> zero block.
                covariance[i][j] = cross if cross is not None else ZeroKern()
        if unique:
            linked.append(kernels[i])

    # Assigned unconditionally: K and the gradient code index into it.
    self.covariance = covariance
    # NOTE(review): objects taken from cross_covariances are not linked here, so any
    # hyperparameters they hold are not exposed to the optimizer -- confirm intended.
    self.link_parameters(*linked)

@Cache_this(limit=3, ignore_args=())
def K(self, X, X2=None):
    """
    Compute the full multi-output covariance between X and X2.

    The last column of X (and X2) holds the output index of each row.
    index_to_slices groups the rows by output (slices[i] lists the row slices of
    output i), and every (output i, output j) block of the result is filled by the
    K method of the corresponding covariance entry.

    :param X: input array whose last column is the output index.
    :param X2: optional second input array with the same layout; defaults to X.
    :returns: array of shape (X.shape[0], X2.shape[0]).
    """
    if X2 is None:
        X2 = X
    slices = index_to_slices(X[:, self.index_dim])
    slices2 = index_to_slices(X2[:, self.index_dim])
    target = np.zeros((X.shape[0], X2.shape[0]))
    for i, slices_i in enumerate(slices):
        for j, slices2_j in enumerate(slices2):
            # A single (non-list) kernel leaves only a 1x1 covariance structure; it
            # is shared by every output pair, as Kdiag shares it across outputs.
            cov = self.covariance[0][0] if self.single_kern else self.covariance[i][j]
            for s in slices_i:
                for s2 in slices2_j:
                    target[s, s2] = cov.K(X[s, :], X2[s2, :])
    return target

@Cache_this(limit=3, ignore_args=())
def Kdiag(self, X):
    """
    Compute the diagonal of K(X, X).

    Rows are grouped by the output index in the last column of X. The rows of
    output i use kernel i, or the single kernel when one (non-list) kernel was
    given. Outputs beyond the length of the kernel list are left at zero.

    :param X: input array whose last column is the output index.
    :returns: 1-D array of length X.shape[0].
    """
    slices = index_to_slices(X[:, self.index_dim])
    # Reuse the single kernel for every output; a plain list avoids needing itertools.
    kerns = [self.kern] * len(slices) if self.single_kern else self.kern
    target = np.zeros(X.shape[0])
    for kern, slices_i in zip(kerns, slices):
        for s in slices_i:
            target[s] = kern.Kdiag(X[s])
    return target

def _update_gradients_full_wrapper(self, kern, dL_dK, X, X2):
    """
    Add the gradient contribution of one covariance block to `kern`.

    One kernel object can serve several blocks of the full covariance (several
    row/column slice pairs, or a kernel repeated across outputs). Because
    `kern.update_gradients_full` is assumed to overwrite (not add to) the kernel's
    gradient, the gradient held before the call is saved and added back, so the
    contributions of all blocks accumulate.

    :param kern: covariance object of this block.
    :param dL_dK: derivative of the objective w.r.t. this block of K.
    :param X: rows of the first input belonging to this block.
    :param X2: rows of the second input belonging to this block.
    """
    previous = kern.gradient.copy()
    kern.update_gradients_full(dL_dK, X, X2)
    kern.gradient += previous

def update_gradients_full(self, dL_dK, X, X2=None):
    """
    Set the gradients of every covariance block from dL_dK.

    dL_dK is the derivative of the objective w.r.t. K(X, X2), with X2 defaulting
    to X. It is cut into the same (output i, output j) blocks that K builds, and
    each block is handed to the matching covariance entry via
    _update_gradients_full_wrapper.

    :param dL_dK: array of shape (X.shape[0], X2.shape[0]).
    :param X: input array whose last column is the output index.
    :param X2: optional second input array with the same layout.
    """
    # The wrapper accumulates, so start every distinct block object from zero.
    seen = set()
    for row in self.covariance:
        for cov in row:
            if id(cov) not in seen:
                seen.add(id(cov))
                cov.gradient[:] = 0.
    if X2 is None:
        X2 = X
    slices = index_to_slices(X[:, self.index_dim])
    slices2 = index_to_slices(X2[:, self.index_dim])
    for i, slices_i in enumerate(slices):
        for j, slices2_j in enumerate(slices2):
            # A single (non-list) kernel serves every output pair, as in Kdiag.
            cov = self.covariance[0][0] if self.single_kern else self.covariance[i][j]
            for s in slices_i:
                for s2 in slices2_j:
                    self._update_gradients_full_wrapper(cov, dL_dK[s, s2], X[s, :], X2[s2, :])