Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/updating.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ Removed `label` parameter since size marker data is now stored as metadata in th
Deprecated the function in favor of the new
[`plantcv.transform.detect_color_card`](transform_detect_color_card.md) function.

#### plantcv.visualize.pixel_scatter_plot

Changed `paths_to_imgs` argument to `source` to reflect that it can use a `numpy.ndarray`, `str` path,
or a list of paths where it previously only could use a list of paths.

#### plantcv.visualize.time_lapse_video

Deprecated the function to enable compatibility with the opencv-headless package. Will be readded in a future release.
Expand Down Expand Up @@ -1575,7 +1580,8 @@ pages for more details on the input and output variable types.
#### plantcv.visualize.pixel_scatter_plot

* pre v4.0: NA
* post v4.0: fig, ax = **pcv.visualize.pixel_scatter_plot**(*paths_to_imgs, x_channel, y_channel*)
* pre v5.0: fig, ax = **pcv.visualize.pixel_scatter_plot**(*paths_to_imgs, x_channel, y_channel*)
* post v5.0: fig, ax = **pcv.visualize.pixel_scatter_plot**(*source, x_channel, y_channel, n=20, ext="png"*)

#### plantcv.visualize.tile

Expand Down
16 changes: 10 additions & 6 deletions docs/visualize_pixel_scatter_vis.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

This function plots a 2D pixel scatter plot visualization for a dataset of images. The horizontal and vertical coordinates are defined by the intensity of the pixels in the specified channels. The color of each dot is given by the original RGB color of the pixel.

**plantcv.visualize.pixel_scatter_plot**(*paths_to_imgs, x_channel, y_channel*)
**plantcv.visualize.pixel_scatter_plot**(*source, x_channel, y_channel, n=20, ext="png"*)

**returns** fig, ax

- **Parameters:**
- paths_to_imgs - List of paths to the images.
- source - Image as a numpy array, string file path to a directory of images, or list of paths to the images.
- x_channel - Channel to use for the horizontal coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'.
- y_channel - Channel to use for the vertical coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'.
- n - Max number of images to use if `source` is a filepath.
- ext - Image file extension to search for if `source` is a filepath.


- **Context:**
Expand All @@ -37,9 +39,9 @@ This function plots a 2D pixel scatter plot visualization for a dataset of image

from plantcv import plantcv as pcv

fig1, ax1 = pcv.visualize.pixel_scatter_plot(paths_to_imgs=file_paths, x_channel='index', y_channel='G')
fig1, ax1 = pcv.visualize.pixel_scatter_plot(source=file_paths, x_channel='index', y_channel='G')

fig2, ax2 = pcv.visualize.pixel_scatter_plot(paths_to_imgs=file_paths, x_channel='index', y_channel='s')
fig2, ax2 = pcv.visualize.pixel_scatter_plot(source=file_paths, x_channel='index', y_channel='s')

```

Expand All @@ -61,9 +63,11 @@ fig2, ax2 = pcv.visualize.pixel_scatter_plot(paths_to_imgs=file_paths, x_channel

from plantcv import plantcv as pcv

fig1, ax1 = pcv.visualize.pixel_scatter_plot(paths_to_imgs=file_paths, x_channel='b', y_channel='a')
fig1, ax1 = pcv.visualize.pixel_scatter_plot(source=file_paths, x_channel='b', y_channel='a')

fig2, ax2 = pcv.visualize.pixel_scatter_plot(paths_to_imgs=file_paths, x_channel='G', y_channel='b')
fig2, ax2 = pcv.visualize.pixel_scatter_plot(source="/path/to/images/", x_channel='G', y_channel='b')

fig3, ax3 = pcv.visualize.pixel_scatter_plot(source=img, x_channel='G', y_channel='b')

```

Expand Down
102 changes: 82 additions & 20 deletions plantcv/plantcv/visualize/pixel_scatter_vis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Visualize a scatter plot of pixels

import os
import numpy as np
import cv2
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -45,29 +45,32 @@ def _not_valid(*args):
return fatal_error("channel not valid, use R, G, B, l, a, b, h, s, v, c, m, y, k, gray, or index")


def pixel_scatter_plot(paths_to_imgs, x_channel, y_channel):
def pixel_scatter_plot(source, x_channel, y_channel, n=20, ext="png"):
"""
Plot a 2D pixel scatter plot visualization for a dataset of images.
The horizontal and vertical coordinates are defined by the intensity of the
pixels in the specified channels.
The color of each dot is given by the original RGB color of the pixel.

Inputs:
paths_to_imgs = List of paths to the images
x_channel = Channel to use for the horizontal coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'
y_channel = Channel to use for the vertical coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'

Returns:
fig = matplotlib pyplot Figure object of the visualization
ax = matplotlib pyplot Axes object of the visualization

:param paths_to_imgs: list of str
:param x_channel: str
:param y_channel: str
:return fig: matplotlib.pyplot Figure object
:return ax: matplotlib.pyplot Axes object
Parameters
----------
source : list, numpy.ndarray, or str,
List of paths to the images, an image as a numpy array, or a path to a starting directory to find images in
x_channel : str,
Channel to use for the horizontal coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'
y_channel : str,
Channel to use for the vertical coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'
n : int,
max number of images to use if source is a string
ext : str,
image file extension to search for if source is a string

Returns
-------
fig : matplotlib pyplot Figure object of the visualization
ax : matplotlib pyplot Axes object of the visualization
"""
# dictionary returns the function that gets the required image channel
channel_dict = {
Expand All @@ -87,13 +90,24 @@ def pixel_scatter_plot(paths_to_imgs, x_channel, y_channel):
'y': _rgb2cmyk,
'k': _rgb2cmyk
}
if isinstance(source, np.ndarray):
fig, ax = _px_scatter_from_img(source, x_channel, y_channel, channel_dict)
return fig, ax
# if not an image then keep going
paths_to_imgs = source
N = len(paths_to_imgs)
if isinstance(source, str):
N = n
paths_to_imgs = []
for root, _, files in os.walk(source):
for file in files:
if file.lower().endswith(ext) and len(paths_to_imgs) < n:
paths_to_imgs.append(os.path.join(root, file))

# store debug mode
debug = params.debug
params.debug = None

N = len(paths_to_imgs)

fig, ax = plt.subplots()
# load and plot the set of images sequentially
for p in paths_to_imgs:
Expand Down Expand Up @@ -126,3 +140,51 @@ def pixel_scatter_plot(paths_to_imgs, x_channel, y_channel):
params.debug = debug

return fig, ax


def _px_scatter_from_img(source, x_channel, y_channel, channel_dict):
"""Make pixel scatter plot from an image

Parameters
----------
source : numpy.ndarray
List of paths to the images, an image as a numpy array, or a path to a starting directory to find images in
x_channel : str,
Channel to use for the horizontal coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'
y_channel : str,
Channel to use for the vertical coordinate of the scatter plot.
Options: 'R', 'G', 'B', 'l', 'a', 'b', 'h', 's', 'v', 'c', 'm', 'y', 'k', 'gray', and 'index'
channel_dict : dict,
dictionary of functions to pull channels. Defined internally in user facing function.

Returns
-------
fig : matplotlib pyplot Figure object of the visualization
ax : matplotlib pyplot Axes object of the visualization
"""
fig, ax = plt.subplots()
h, _, c = source.shape
# resizing to predetermined width to reduce the number of pixels
ratio = h/IMG_WIDTH
img_height = int(IMG_WIDTH*ratio)
# nearest interpolation avoids mixing pixel values
sub_img = cv2.resize(source, (IMG_WIDTH, img_height), interpolation=cv2.INTER_NEAREST)

# organize the channels as RGB to use as facecolor for the markers
sub_img_rgb = cv2.cvtColor(sub_img, cv2.COLOR_BGR2RGB)
fcolors = sub_img_rgb.reshape(img_height*IMG_WIDTH, c)/255

# get channels
sub_img_x_ch = channel_dict.get(x_channel, _not_valid)(sub_img, x_channel)
sub_img_y_ch = channel_dict.get(y_channel, _not_valid)(sub_img, y_channel)

ax.scatter(sub_img_x_ch.reshape(-1),
sub_img_y_ch.reshape(-1),
alpha=0.05, s=MAX_MARKER_SIZE,
edgecolors=None, facecolors=fcolors)

plt.xlabel(x_channel)
plt.ylabel(y_channel)

return fig, ax
32 changes: 30 additions & 2 deletions tests/plantcv/visualize/test_pixel_scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_pixel_scatter_plot(ch, tmpdir):
path_to_img = os.path.join(cache_dir, 'tmp_img.png')
cv2.imwrite(path_to_img, img)
# test the function with a list of one path to the random image
_, _ = pixel_scatter_plot(paths_to_imgs=[path_to_img], x_channel=ch, y_channel='index')
_, _ = pixel_scatter_plot(source=[path_to_img], x_channel=ch, y_channel='index')
assert 1


Expand All @@ -33,4 +33,32 @@ def test_pixel_scatter_plot_wrong_ch(tmpdir):
cv2.imwrite(path_to_img, img)
# test the function with channel parameter that is not an option
with pytest.raises(RuntimeError):
_, _ = pixel_scatter_plot(paths_to_imgs=[path_to_img], x_channel='wrong_ch', y_channel='index')
_, _ = pixel_scatter_plot(source=[path_to_img], x_channel='wrong_ch', y_channel='index')


def test_pixel_scatter_plot_str_source(tmpdir):
"""Test for PlantCV."""
# Create a tmp directory
cache_dir = tmpdir.mkdir("cache")
for i in range(2):
rng = np.random.default_rng()
img_size = (10, 10, 3)
# create a random image and write it to the temp directory
img = rng.integers(low=0, high=255, size=img_size, dtype=np.uint8, endpoint=True)
path_to_img = os.path.join(cache_dir, f'tmp_img_{i}.png')
cv2.imwrite(path_to_img, img)
# test the function with an str path
_, _ = pixel_scatter_plot(source=str(cache_dir), x_channel="R", y_channel="index")
assert 1


def test_pixel_scatter_plot_img_source():
"""Test for PlantCV."""
# Create an image
rng = np.random.default_rng()
img_size = (10, 10, 3)
# create a random image and write it to the temp directory
img = rng.integers(low=0, high=255, size=img_size, dtype=np.uint8, endpoint=True)
# test the function with an str path
_, _ = pixel_scatter_plot(source=img, x_channel="l", y_channel="a")
assert 1