Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/preparing_datasets/basic_dataset_conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ out = ('/local/data', 'oci://bucket/data')
| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 |
| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 |
| Image | 'pil' | `PIL` | raw PIL image class ([link](https://pillow.readthedocs.io/en/stable/reference/Image.html)) |
| Image | 'jpeg' | `JPEG` | PIL image as JPEG |
| Image | 'jpeg:quality' | `JPEG` | PIL image as JPEG; `quality` is an optional integer between 0 and 100 (default 75), e.g. 'jpeg:90' |
| Image | 'png' | `PNG` | PIL image as PNG |
| Pickle | 'pkl' | `Pickle` | arbitrary Python objects |
| JSON | 'json' | `JSON` | arbitrary data as JSON |
Expand All @@ -52,7 +52,7 @@ Here's an example where the field `x` is an image, and `y` is a class label, as
<!--pytest.mark.skip-->
```python
column = {
'x': 'jpeg',
'x': 'jpeg:90',
'y': 'int8',
}
```
Expand Down
24 changes: 23 additions & 1 deletion streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,28 @@ def decode(self, data: bytes) -> Image.Image:
class JPEG(Encoding):
"""Store PIL image as JPEG."""

def __init__(self, quality: int = 75):
    """Initialize the JPEG encoding.

    Args:
        quality (int): JPEG quality used when saving images, between 0 and 100
            inclusive. Defaults to ``75``.

    Raises:
        ValueError: If ``quality`` is outside the range ``[0, 100]``.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # when Python is started with ``-O`` (which strips assert statements).
    if not 0 <= quality <= 100:
        raise ValueError(f'JPEG quality must be between 0 and 100, but got: {quality}.')
    self.quality = quality

@classmethod
def from_str(cls, text: str) -> Self:
Comment thread
cabreraalex marked this conversation as resolved.
Outdated
"""Parse this encoding from string.

Args:
text (str): The string to parse.

Returns:
Self: The initialized Encoding.
"""
args = text.split(':') if text else []
assert len(args) in {0, 1}
if len(args) == 1:
quality = int(args[0])
else:
quality = 75
Comment thread
snarayan21 marked this conversation as resolved.
Outdated
return cls(quality)

def encode(self, obj: Image.Image) -> bytes:
self._validate(obj, Image.Image)
if isinstance(obj, JpegImageFile) and hasattr(obj, 'filename'):
Expand All @@ -474,7 +496,7 @@ def encode(self, obj: Image.Image) -> bytes:
return f.read()
else:
out = BytesIO()
obj.save(out, format='JPEG')
obj.save(out, format='JPEG', quality=self.quality)
return out.getvalue()

def decode(self, data: bytes) -> Image.Image:
Expand Down