Source code for GPy.plotting.matplot_dep.img_plots

# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
"""
The module contains the tools for ploting 2D image visualizations
"""

import numpy as np
from matplotlib.cm import jet

width_max = 15
height_max = 12

def _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad):
    width = (x_size*fig_ncols+pad*(fig_ncols-1))
    height = (y_size*fig_nrows+pad*(fig_nrows-1))
    if width > float(height)/height_max*width_max:
        return (width_max, float(width_max)/width*height)
    else:
        return (float(height_max)/height*width, height_max)

[docs]def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None, interpolation='nearest'): ax = figure.add_subplot(111) if len(arr.shape)==2: arr = arr.reshape(*((1,)+arr.shape)) fig_num = arr.shape[0] y_size = arr.shape[1] x_size = arr.shape[2] fig_ncols = int(np.ceil(np.sqrt(fig_num))) fig_nrows = int(np.ceil((float)(fig_num)/fig_ncols)) if pad==None: pad = max(int(min(y_size,x_size)/10),1) figsize = _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad) #figure.set_size_inches(figsize,forward=True) #figure.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) if symmetric: # symmetric around zero: fix zero as the middle color mval = max(abs(arr.max()),abs(arr.min())) arr = arr/(2.*mval)+0.5 else: minval,maxval = arr.min(),arr.max() arr = (arr-minval)/(maxval-minval) if mode=='L': arr_color = np.empty(arr.shape+(3,)) arr_color[:] = arr.reshape(*(arr.shape+(1,))) elif mode==None or mode=='jet': arr_color = jet(arr) buf = np.ones((y_size*fig_nrows+pad*(fig_nrows-1), x_size*fig_ncols+pad*(fig_ncols-1), 3),dtype=arr.dtype) for y in range(fig_nrows): for x in range(fig_ncols): if y*fig_ncols+x<fig_num: buf[y*y_size+y*pad:(y+1)*y_size+y*pad, x*x_size+x*pad:(x+1)*x_size+x*pad] = arr_color[y*fig_ncols+x,:,:,:3] img_plot = ax.imshow(buf, interpolation=interpolation) ax.axis('off')