Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion torchdata/nodes/_populate_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import queue
import threading
from typing import Any, Dict, Optional, Union
Expand Down Expand Up @@ -60,7 +61,10 @@ def _put(
assert (
isinstance(snapshot_frequency, int) and snapshot_frequency >= 0
), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}"
snapshot_store.append_initial_snapshot(snapshot=source.state_dict())
snapshot = source.state_dict()
if snapshot is not None:
snapshot = copy.deepcopy(snapshot)
snapshot_store.append_initial_snapshot(snapshot=snapshot)
except Exception:
e = StartupExceptionWrapper(where="in _populate_queue startup for device")
snapshot_store.append_initial_snapshot(snapshot=e)
Expand All @@ -76,6 +80,8 @@ def _put(
snapshot = None
if snapshot_frequency > 0 and yielded % snapshot_frequency == 0:
snapshot = source.state_dict()
if snapshot is not None:
snapshot = copy.deepcopy(snapshot)
Comment thread
alexdremov marked this conversation as resolved.
_put(item, block=False, snapshot=snapshot)
except StopIteration as e:
_put(e, block=False)
Expand Down