Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/preparing_datasets/basic_dataset_conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ out = ('/local/data', 'oci://bucket/data')
| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 |
| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 |
| Image | 'pil' | `PIL` | raw PIL image class ([link]((https://pillow.readthedocs.io/en/stable/reference/Image.html))) |
| Image | 'jpeg' | `JPEG` | PIL image as JPEG |
| Image | 'jpeg:quality' | `JPEG` | PIL image as JPEG, quality between 0 and 100 |
| Image | 'png' | `PNG` | PIL image as PNG |
| Pickle | 'pkl' | `Pickle` | arbitrary Python objects |
| JSON | 'json' | `JSON` | arbitrary data as JSON |
Expand All @@ -52,7 +52,7 @@ Here's an example where the field `x` is an image, and `y` is a class label, as
<!--pytest.mark.skip-->
```python
column = {
'x': 'jpeg',
'x': 'jpeg:90',
'y': 'int8',
}
```
Expand Down
26 changes: 24 additions & 2 deletions streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,29 @@ def decode(self, data: bytes) -> Image.Image:


class JPEG(Encoding):
"""Store PIL image as JPEG."""
"""Store PIL image as JPEG. Optionally specify quality."""

def __init__(self, quality: int = 75):
if not isinstance(quality, int):
raise ValueError('JPEG quality must be an integer')
if not (0 <= quality <= 100):
raise ValueError('JPEG quality must be between 0 and 100')
self.quality = quality

@classmethod
def from_str(cls, config: str) -> Self:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably also add a test for this too then, to also confirm what it should look like when used. Thanks!

"""Parse this encoding from string.

Args:
text (str): The string to parse.

Returns:
Self: The initialized Encoding.
"""
if config == '':
return cls()
else:
return cls(int(config))

def encode(self, obj: Image.Image) -> bytes:
self._validate(obj, Image.Image)
Expand All @@ -474,7 +496,7 @@ def encode(self, obj: Image.Image) -> bytes:
return f.read()
else:
out = BytesIO()
obj.save(out, format='JPEG')
obj.save(out, format='JPEG', quality=self.quality)
return out.getvalue()

def decode(self, data: bytes) -> Image.Image:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,30 @@ def test_jpeg_encode_decode(self, mode: str):
dec_data = dec_data.convert('I')
assert isinstance(dec_data, Image.Image)

@pytest.mark.parametrize('mode', ['L', 'RGB'])
def test_jpeg_encode_decode_with_quality(self, mode: str):
jpeg_enc = mdsEnc.JPEG(quality=50)
assert jpeg_enc.size is None

# Creating the (32 x 32) NumPy Array with random values
np_data = np.random.randint(255, size=(32, 32), dtype=np.uint32)
# Default image mode of PIL Image is 'I'
img = Image.fromarray(np_data).convert(mode)

# Test encode
enc_data = jpeg_enc.encode(img)
assert isinstance(enc_data, bytes)

# Test decode
dec_data = jpeg_enc.decode(enc_data)
dec_data = dec_data.convert('I')
assert isinstance(dec_data, Image.Image)

@pytest.mark.parametrize('quality', [-1, 101, 'foo'])
def test_jpeg_encode_decode_with_quality_invalid(self, quality: Any):
with pytest.raises(ValueError):
mdsEnc.JPEG(quality=quality)

@pytest.mark.parametrize('mode', ['L', 'RGB'])
def test_jpegfile_encode_decode(self, mode: str):
jpeg_enc = mdsEnc.JPEG()
Expand Down Expand Up @@ -224,6 +248,7 @@ def test_jpeg_encode_invalid_data(self, data: Any):
with pytest.raises(AttributeError):
jpeg_enc = mdsEnc.JPEG()
_ = jpeg_enc.encode(data)


@pytest.mark.parametrize('mode', ['I', 'L', 'RGB'])
def test_png_encode_decode(self, mode: str):
Expand Down