# Source code for kalmus.utils.visualization_utils

""" Visualization Utility """

import numpy as np
import matplotlib.pyplot as plt
from skimage.color import hsv2rgb, rgb2hsv
from collections import Counter
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl

from kalmus.utils.artist import get_contrast_matrix_and_labeled_image


def show_color(color, figure_size=(9, 6), title="",
               axis_off=False, save_image=False, file_name="test.png", return_color=False):
    """
    Display a single RGB color as a solid 30 x 30 pixel swatch.

    :param color: 1D array holding the R, G, and B channel values
    :type color: numpy.ndarray
    :param figure_size: Size of the plotted figure
    :type figure_size: tuple
    :param title: Title drawn above the swatch
    :type title: str
    :param axis_off: True to hide the axes of the figure. False to keep them
    :type axis_off: bool
    :param save_image: True to save the figure to file_name before it is shown. False not to save it
    :type save_image: bool
    :param file_name: Path of the saved figure. Defaults to test.png
    :type file_name: str
    :param return_color: True to return the swatch array (shape==30 x 30 x 3) without plotting. \
    False to plot it
    :type return_color: bool
    :return: The uint8 swatch array if return_color is True, otherwise nothing
    :rtype: numpy.ndarray
    """
    swatch_side = 30
    # The color is cast to uint8 and broadcast over a canvas of ones
    canvas = np.ones(shape=(swatch_side, swatch_side, 3), dtype=np.uint8)
    swatch = np.uint8(color) * canvas

    if return_color:
        return swatch

    plt.figure(figsize=figure_size)
    plt.imshow(swatch)
    plt.title(title)

    if axis_off:
        plt.axis('off')

    # Saving has to happen before show()
    if save_image:
        plt.savefig(file_name)
    plt.show()
def show_colors_in_sequence(colors, figure_size=(9, 6), title="", axis_off=False,
                            save_image=False, file_name="test.png", horizontal=True,
                            return_color_sequence=False):
    """
    Plot a sequence of RGB colors either horizontally or vertically in line

    :param colors: 2D array of RGB colors. Expected sequence shape==Number of colors x channels (3)
    :type colors: numpy.ndarray
    :param figure_size: the size of the plot figure
    :type figure_size: tuple
    :param title: the title of the plot
    :type title: str
    :param axis_off: True to set the axis of the figure off. False to have the axis with the figure
    :type axis_off: bool
    :param save_image: True to save the plot figure. False do not save figure
    :type save_image: bool
    :param file_name: the path of the saved figure
    :type file_name: str
    :param horizontal: True to plot the sequence of colors horizontally, False to plot the sequence of colors \
    vertically
    :type horizontal: bool
    :param return_color_sequence: Return the numpy array (shape==30 x (30 * num_colors) x 3 when horizontal, \
    (30 * num_colors) x 30 x 3 when vertical) of color sequence if True. Directly plot the color sequence if False
    :type return_color_sequence: bool
    :return: Return the numpy array of the color sequence image if return_color_sequence is True.
    :rtype: numpy.ndarray
    :raises ValueError: if colors is empty
    """
    if len(colors) == 0:
        raise ValueError("colors must contain at least one RGB color")

    # Size of the color blocks in the display chain/sequence
    block_size = 30
    # One canvas of ones is enough; every color block is that canvas scaled by the color
    unit_block = np.ones(shape=(block_size, block_size, 3), dtype=np.uint8)

    # Stacking along axis 0 yields a vertical chain of blocks
    clr_chain = np.concatenate([unit_block * np.uint8(color) for color in colors])

    # Swap rows and columns to lay the chain out horizontally
    if horizontal:
        clr_chain = clr_chain.transpose([1, 0, 2])

    if return_color_sequence:
        return clr_chain

    plt.figure(figsize=figure_size)
    plt.imshow(clr_chain)
    plt.title(title)
    if axis_off:
        plt.axis("off")
    if save_image:
        plt.savefig(file_name)
    plt.show()
def _pad_color_rows(color_rows):
    """Lay every row of colors out as one image column, white-padded/cut to the first row's length."""
    height = len(color_rows[0])
    white_column = np.full((height, 1, 3), 255, dtype=np.uint8)
    # NOTE(review): the original implementation seeded the matrix with an all-white
    # column that was never removed; it is kept so the returned shape does not change.
    columns = [white_column]
    for color_row in color_rows:
        column = white_column.copy()
        kept = np.asarray(color_row[:height])
        if len(kept):
            column[:len(kept), 0] = kept.astype(np.uint8)
        columns.append(column)
    return np.concatenate(columns, axis=1)


def show_color_matrix(color_2d_array, mode="truncate", figure_size=(9, 4), title="", axis_off=False,
                      save_image=False, file_name="Block.png", return_matrix=False, return_figure=False):
    """
    Show a matrix of colors (color barcodes). Two modes are available "truncate" and "padding"
    "truncate" mode will truncate the last row of the color matrix if it has different len with respect to
    the other rows.
    "padding" mode will pad white colors to the last row until the last row has the same length as the other rows

    Each row of the input becomes one column of the displayed image.

    :param color_2d_array: the input color matrix. Expected the shape of the color matrix to be row x col x channels \
    channels should 3 (R, G, and B channels of RGB colors)
    :type color_2d_array: numpy.ndarray
    :param mode: mode for displaying color matrix. "truncate" or "padding"
    :type mode: str
    :param figure_size: the size of the figure
    :type figure_size: tuple
    :param title: the title of the plot
    :type title: str
    :param axis_off: True to set the axis of the figure off
    :type axis_off: bool
    :param save_image: True to save the plot figure, False do not save the figure
    :type save_image: bool
    :param file_name: the path to the saved figure
    :type file_name: str
    :param return_matrix: True to return the processed color matrix back (nothing is plotted). \
    False not to return the matrix
    :type return_matrix: bool
    :param return_figure: Return the plotted figure, if true. Directly plot the color matrix if false
    :type return_figure: bool
    :return: the processed uint8 color matrix (shape col x row x 3 in "truncate" mode, col x (row + 1) x 3 with \
    a leading white column in "padding" mode) if return_matrix is True. \
    Or, return the figure with plotted color matrix if return_figure is True.
    :rtype: numpy.ndarray if return_matrix, matplotlib.pyplot.Figure if return_figure
    """
    assert len(color_2d_array.shape) == 3, "Only 2d array of three-channel color tuple/array accepted"
    assert color_2d_array.shape[-1] == 3, "The number of channels of color 2d array must be 3 (R, G, B)"
    assert mode == "truncate" or mode == "padding", "Invalid mode for displaying color matrix, two " \
                                                    "modes are available 'truncate' or 'padding'"

    if mode == "padding":
        # Built in one concatenate instead of one np.append per color
        clr_matrix = _pad_color_rows(color_2d_array)
    else:
        rows = color_2d_array
        # Only compare the last two rows when there are two; a single row is kept as is
        if len(rows) > 1 and len(rows[-1]) != len(rows[-2]):
            rows = rows[:-1]
        # Transpose the matrix to have a horizontal display
        clr_matrix = np.array(rows).astype("uint8").transpose([1, 0, 2])

    if return_matrix:
        return clr_matrix

    fig = plt.figure(figsize=figure_size)
    plt.imshow(clr_matrix)
    plt.title(title)
    if axis_off:
        plt.axis('off')
    if save_image:
        plt.savefig(file_name)

    if return_figure:
        return fig
    plt.show()
def show_colors_in_cube(colors, figure_size=(8, 8), tight_plot=True, grid_off=True, background_off=True,
                        return_figure=False, sampling=-1, return_sampled_colors=False):
    """
    Show a sequence of RGB colors in cubic RGB space (e.g. R axis (x axis), G axis (y axis), and B axis (z axis))
    see https://matplotlib.org/3.1.1/gallery/mplot3d/scatter3d.html for more references

    ipympl is required for this function to work in the jupyter notebook

    :param colors: Colors to display in cubic space. Any array whose last axis has length 3 is accepted \
    (e.g. number of colors x 3, or an image of rows x cols x 3); it is flattened to number of colors x 3
    :type colors: numpy.ndarray
    :param figure_size: the size of the figure
    :type figure_size: tuple
    :param tight_plot: Remove the white margin around the plot if true \
    Keep the original white margin if false
    :type tight_plot: bool
    :param grid_off: Remove the grid of plot if true \
    Keep the original grid if false
    :type grid_off: bool
    :param background_off: Use the transparent background if true \
    Keep the original background if false
    :type background_off: bool
    :param return_figure: Return the plotted figure and axes, if true \
    Directly plot the cube, if false
    :type return_figure: bool
    :param sampling: Number of data points plot in the 3D cube, drawn at random without replacement. \
    Values larger than the number of colors are capped to it. -1 (or any non-positive value) to include \
    all data in the plot
    :type sampling: int
    :param return_sampled_colors: Return the sampled out colors if True. Directly plot the colors in RGB cube if False
    :type return_sampled_colors: bool
    :return: return the sampled out colors (shape number of colors x 3) if return_sampled_colors is True \
    or, return the figure and axes with plotted figure if return_figure is True
    :rtype: numpy.ndarray, or tuple (matplotlib.pyplot.Figure, matplotlib.pyplot.Axes)
    """
    assert colors.shape[-1] == 3, "The input colors must be a 2D numpy array with RGB colors where " \
                                  "shape of array is number of colors x channels(3)"

    # Flatten BEFORE sampling: the sample indices address individual colors,
    # which only matches axis 0 once the array is number of colors x 3.
    flat_colors = colors.reshape(-1, 3)

    if sampling > 0:
        # Sampling without replacement cannot draw more points than exist
        num_samples = min(sampling, len(flat_colors))
        sample_indices = np.random.choice(len(flat_colors), num_samples, replace=False)
        flat_colors = flat_colors[sample_indices]
    else:
        # Never hand a view of the caller's array back
        flat_colors = flat_colors.copy()

    if return_sampled_colors:
        return flat_colors

    fig = plt.figure(figsize=figure_size)
    ax = fig.add_subplot(111, projection='3d')

    # Marker colors must be in [0, 1]; values above 1 are treated as 0-255 colors
    if flat_colors.max() > 1:
        marker_colors = flat_colors.astype("float32") / 255
    else:
        marker_colors = flat_colors
    ax.scatter(flat_colors[..., 0], flat_colors[..., 1], flat_colors[..., 2], c=marker_colors)

    # Label the axes.
    ax.set_xlabel('Red')
    ax.set_ylabel('Green')
    ax.set_zlabel('Blue')
    ax.grid(not grid_off)

    if tight_plot:
        fig.subplots_adjust(left=0.10, right=0.90, bottom=0.10, top=0.90)

    if background_off:
        # Fully transparent panes
        ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    if return_figure:
        return fig, ax
    plt.show()
def show_high_contrast_region(image, minimum_segment_size=0.0004, figsize=(6, 4), return_region_image=False):
    """
    Plot the region of the image with the highest summed contrast.
    Segmentation and the contrast matrix both come from get_contrast_matrix_and_labeled_image.

    :param image: input color image shape=(row (height), col (width), channel(3))
    :type image: numpy.ndarray
    :param minimum_segment_size: The minimum size (in relative ratio [0, 1]) for the segments in the segmented image. \
    Forwarded unchanged to get_contrast_matrix_and_labeled_image
    :type minimum_segment_size: float
    :param figsize: the size of the plot figure
    :type figsize: tuple
    :param return_region_image: Return the image with only the high contrast region if True. \
    Directly plot the image if False
    :type return_region_image: bool
    :return: If return_region_image is True, the image produced by extract_region_with_index \
    for the selected region. Otherwise nothing
    :rtype: numpy.ndarray
    """
    segmentation = get_contrast_matrix_and_labeled_image(image, minimum_segment_size=minimum_segment_size)
    contrast_matrix, labeled_image = segmentation

    # Row of the contrast matrix with the largest total
    summed_contrast = np.sum(contrast_matrix, axis=1)
    # presumably labels in labeled_image start at 1, hence the +1 offset -- TODO confirm
    region_label = summed_contrast.argmax() + 1
    region_image = extract_region_with_index(image, region_label, labeled_image)

    if not return_region_image:
        plt.figure(figsize=figsize)
        plt.imshow(region_image)
        plt.show()
        return None
    return region_image
def show_low_contrast_region(image, minimum_segment_size=0.0004, figsize=(6, 4), return_region_image=False):
    """
    Plot the region of the image with the lowest summed contrast.
    Segmentation and the contrast matrix both come from get_contrast_matrix_and_labeled_image.

    :param image: input color image shape=(row (height), col (width), channel(3))
    :type image: numpy.ndarray
    :param minimum_segment_size: The minimum size (in relative ratio [0, 1]) for the segments in the segmented image. \
    Forwarded unchanged to get_contrast_matrix_and_labeled_image
    :type minimum_segment_size: float
    :param figsize: the size of the plot figure
    :type figsize: tuple
    :param return_region_image: Return the image with only the low contrast region if True. \
    Directly plot the image if False
    :type return_region_image: bool
    :return: If return_region_image is True, the image produced by extract_region_with_index \
    for the selected region. Otherwise nothing
    :rtype: numpy.ndarray
    """
    segmentation = get_contrast_matrix_and_labeled_image(image, minimum_segment_size=minimum_segment_size)
    contrast_matrix, labeled_image = segmentation

    # Row of the contrast matrix with the smallest total
    summed_contrast = np.sum(contrast_matrix, axis=1)
    # presumably labels in labeled_image start at 1, hence the +1 offset -- TODO confirm
    region_label = summed_contrast.argmin() + 1
    region_image = extract_region_with_index(image, region_label, labeled_image)

    if not return_region_image:
        plt.figure(figsize=figsize)
        plt.imshow(region_image)
        plt.show()
        return None
    return region_image
def extract_region_with_index(image, region_index, labeled_image):
    """
    Helper function that masks the input image with the labeled image and keeps only the region
    whose label equals region_index. Every pixel outside that region is painted with a flat
    background: black when the region is bright, white when the region is dark.

    :param image: input image, shape=(row(height), col(width)) or (row(height), col(width), channels)
    :type image: numpy.ndarray
    :param region_index: the extracted region index in the labeled image
    :type region_index: int
    :param labeled_image: the labeled image of the corresponding input image, \
    shape=(row(height), col(width))
    :type labeled_image: numpy.ndarray
    :return: Copy of the image with only the region with corresponding index on labeled_image; \
    the input image is not modified.
    :rtype: numpy.ndarray
    """
    extract_part = image.copy()
    image_max = image.max()
    region_pixels = image[labeled_image == region_index]

    # Mean brightness of the region relative to the brightest pixel of the image.
    # An all-zero image or an empty region has no meaningful ratio; both are
    # treated as dark, which is also where the unguarded division (NaN) ended up.
    if image_max != 0 and region_pixels.size > 0:
        relative_brightness = np.average(region_pixels) / float(image_max)
    else:
        relative_brightness = 0.0

    if relative_brightness > 0.5:
        # Bright region: paint the background black
        background_level = 0
    elif image_max <= 1:
        # Dark region on a 0-1 normalized image: white == 1
        background_level = 1
    else:
        # Dark region on a 0-255 image: white == 255
        background_level = 255

    # A scalar fill broadcasts over every channel, so 2-D and multi-channel images both work
    extract_part[labeled_image != region_index] = np.array(background_level, dtype=image.dtype)
    return extract_part
def show_colors_in_hue_light_scatter_plot(colors, figure_size=(10, 5), return_figure=False,
                                          remove_border=False, saturation_threshold=0.15):
    """
    Show a sequence of RGB colors in a Hue vs. Light Scatter Plot (Hue on x-axis and Light on y-axis).
    Colors are assumed to be in the RGB colorspace and will be converted to the HSV
    color space within this function. The marker size of a color grows with how often its
    (integer hue, light) combination occurs in the input.

    :param colors: A sequence of RGB colors to display, either in range [0, 255] or in range [0, 1]
    :type colors: numpy.ndarray
    :param figure_size: the size of the figure
    :type figure_size: tuple
    :param return_figure: Return the plotted figure and axes, if true \
    Directly plot the scatter plot, if false
    :type return_figure: bool
    :param remove_border: Remove the frame border of the plot if true \
    Keep the original frame border if false
    :type remove_border: bool
    :param saturation_threshold: The color of which the saturation value \
    is under the threshold will be excluded \
    from the plots
    :type saturation_threshold: float
    :return: return the figure and axes with plotted figure if return_figure is True
    :rtype: tuple (matplotlib.pyplot.Figure, matplotlib.pyplot.Axes)
    :raises ValueError: if no color is left after the saturation filter
    """
    # Convert colors to a float numpy array
    normalized_colors = np.array(colors).astype("float")
    # If the RGB color is in range [0, 255] normalize it to [0, 1];
    # colors that are already in [0, 1] are used as they are
    if normalized_colors.max() > 1:
        normalized_colors = normalized_colors / 255

    # Convert RGB color to HSV color space and drop the weakly saturated colors
    hsv_colors = rgb2hsv(normalized_colors.reshape(-1, 1, 3))
    hsv_colors = hsv_colors[hsv_colors[..., 1] > saturation_threshold]
    if hsv_colors.size == 0:
        raise ValueError("No color has a saturation above the saturation_threshold")

    # Hue in degrees and Light (the V channel) of each remaining color
    hue = hsv_colors[..., 0] * 360
    bri = hsv_colors[..., 2]

    # Frequency of every (integer hue, light) combination, looked up per color
    combos = list(zip(hue.astype("int").ravel().tolist(), bri.ravel().tolist()))
    weight_counter = Counter(combos)
    weights = np.array([weight_counter[combo] for combo in combos]).astype("float")

    # Rescale the frequencies to marker sizes in [5, 100]. When every combination is
    # equally frequent the spread is 0; dividing by it would turn all sizes into NaN.
    weights -= weights.min()
    spread = weights.max()
    if spread > 0:
        weights /= spread
    weights *= 95
    weights += 5

    # Marker colors: full saturation and a brightened value channel
    hsv_colors_cp = hsv_colors.copy()
    hsv_colors_cp[..., 1] = 1
    hsv_colors_cp[..., 2] = hsv_colors_cp[..., 2] ** (2 / 3)
    rgb_colors = hsv2rgb(hsv_colors_cp.reshape(-1, 1, 3))

    fig, ax = plt.subplots(1, 1, figsize=figure_size)
    ax.scatter(hue, bri, s=weights, marker="s", c=rgb_colors.reshape(-1, 3))

    # Colorbars used as axis legends: grayscale on the right, hue wheel at the bottom
    divider = make_axes_locatable(ax)
    ax_cb1 = divider.new_horizontal(size="5%", pad="1%")
    ax_cb2 = divider.append_axes("bottom", size="5%", pad="12%")
    # plt.get_cmap instead of mpl.cm.get_cmap, which newer Matplotlib releases removed
    cb1 = mpl.colorbar.ColorbarBase(ax_cb1, cmap=plt.get_cmap('binary_r', 256), orientation='vertical')
    cb2 = mpl.colorbar.ColorbarBase(ax_cb2, cmap=plt.get_cmap('hsv', 360), orientation='horizontal')
    cb1.set_ticks([])
    cb2.set_ticks([])

    ax.set_xlabel("Hue (0 - 360)")
    ax.set_ylabel("Light (0 - 1)")
    ax.set_ylim(-0.01, 1.01)
    ax.set_xlim(0, 359)

    fig.add_axes(ax_cb1)
    fig.add_axes(ax_cb2)

    if remove_border:
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

    if return_figure:
        return fig, ax
    else:
        plt.show()
def show_colors_in_hue_light_3d_bar_plot(colors, figure_size=(6, 6), hue_resolution=10, bri_resolution=0.02,
                                         return_figure=False, grid_off=True, background_off=True,
                                         shaded=False, tight_plot=True, axes=None, saturation_threshold=0.15):
    """
    Show a sequence of RGB colors in a Hue vs. Light vs. Counts 3D bar Plot
    (Hue on x-axis, Light on y-axis, and counts/frequency of color on z-axis).
    Colors are assumed to be in the RGB colorspace and will be converted to the HSV
    color space within this function.

    :param colors: A sequence of RGB colors to display, either in range [0, 255] or in range [0, 1]
    :type colors: numpy.ndarray
    :param figure_size: the size of the figure
    :type figure_size: tuple
    :param hue_resolution: Resolution of the bar plot on the hue axis (hue 0 - 360)
    :type hue_resolution: float
    :param bri_resolution: Resolution of the bar plot on the light axis (light 0 - 1)
    :type bri_resolution: float
    :param return_figure: Return the plotted figure and axes, if true \
    Directly plot the bar plot, if false. Ignored when axes is given
    :type return_figure: bool
    :param grid_off: Remove the grid of plot if true \
    Keep the original grid if false
    :type grid_off: bool
    :param background_off: Use the transparent background if true \
    Keep the original background if false
    :type background_off: bool
    :param shaded: Shaded the 3D bar object if true \
    Unshaded if false
    :type shaded: bool
    :param tight_plot: Remove the white margin around the plot if true \
    Keep the original white margin if false. Only applied when the function creates the figure itself
    :type tight_plot: bool
    :param axes: The matplotlib Axes object (with 3D projection). If given function will plot the \
    bar plot in the given Axes and return nothing
    :type axes: matplotlib.pyplot.Axes
    :param saturation_threshold: The color of which the saturation value \
    is under the threshold will be excluded \
    from the plots
    :type saturation_threshold: float
    :return: return the figure and axes with plotted figure if return_figure is True and axes is None
    :rtype: tuple (matplotlib.pyplot.Figure, matplotlib.pyplot.Axes)
    :raises ValueError: if no color is left after the saturation filter
    """
    # Convert colors to a float numpy array
    normalized_colors = np.array(colors).astype("float")
    # If the RGB color is in range [0, 255] normalize it to [0, 1];
    # colors that are already in [0, 1] are used as they are
    if normalized_colors.max() > 1:
        normalized_colors = normalized_colors / 255

    # Convert RGB color to HSV color space and drop the weakly saturated colors
    hsv_colors = rgb2hsv(normalized_colors.reshape(-1, 1, 3))
    hsv_colors = hsv_colors[hsv_colors[..., 1] > saturation_threshold]
    if hsv_colors.size == 0:
        raise ValueError("No color has a saturation above the saturation_threshold")

    # Hue in degrees and Light (the V channel) of each remaining color
    hue = hsv_colors[..., 0] * 360
    bri = hsv_colors[..., 2]

    # Discretizing the Hue value based on the given resolution
    hue_int = (hue / hue_resolution).astype(int) * hue_resolution

    # Discretizing the light value on a 0 - 100 scale based on the given resolution
    bri_step = bri_resolution * 100
    bri_int = (bri * 100 / bri_step).astype(int) * bri_step

    # Get the unique (hue, light) bins and their counts/frequency in the given colors
    hue_bri = np.column_stack((hue_int, bri_int)).astype("float")
    unique_colors, counts = np.unique(hue_bri, axis=0, return_counts=True)

    # Normalize the bins back to [0, 1] to derive the bar colors
    norm_colors = unique_colors.astype("float")
    norm_colors[..., 0] = norm_colors[..., 0] / 360
    norm_colors[..., 1] = norm_colors[..., 1] / 100
    # Brighten the value channel; shaded and unshaded bars use different curves
    if shaded:
        norm_colors[..., 1] = np.clip((norm_colors[..., 1] ** (1 / 2)) * 1.25, 0, 1)
    else:
        norm_colors[..., 1] = np.clip(norm_colors[..., 1] * 1.1, 0, 1) ** (2 / 3)

    # Convert them back to the RGB colorspace at full saturation
    bar_hsv = np.hstack((norm_colors[..., 0].reshape(-1, 1),
                         np.ones(shape=(norm_colors.shape[0], 1)),
                         norm_colors[..., 1].reshape(-1, 1)))
    rgb_colors = hsv2rgb(bar_hsv.reshape(-1, 1, 3)).reshape(-1, 3)

    if axes is None:
        fig = plt.figure(figsize=figure_size)
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = axes

    x = unique_colors[..., 0]
    y = unique_colors[..., 1] / 100
    top = counts
    bottom = np.zeros_like(top)
    # Bars are slightly narrower than a bin so neighbours do not touch
    width = hue_resolution * 0.9
    depth = bri_step * 0.9 / 100
    ax.bar3d(x, y, bottom, width, depth, top, shade=shaded, color=rgb_colors)

    ax.set_xlabel("Hue (0 - 360)")
    ax.set_ylabel("Light (0 - 1)")
    ax.set_zlabel("Number of Frames")
    ax.grid(not grid_off)

    # The figure margins are only touched when this function owns the figure
    if tight_plot and (axes is None):
        fig.subplots_adjust(left=0.10, right=0.90, bottom=0.10, top=0.90)

    if background_off:
        # Fully transparent panes
        ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Drawing into caller-supplied axes: nothing to return or show
    if axes is not None:
        return

    if return_figure:
        return fig, ax
    else:
        plt.show()