Add register_tracker_class to register custom trackers by name#4060
Add register_tracker_class to register custom trackers by name#4060javierdejesusda wants to merge 3 commits into
Conversation
SunMarc
left a comment
There was a problem hiding this comment.
Left a couple of comments thanks !
| for custom_name in log_with: | ||
| if custom_name not in LoggerType and custom_name in LOGGER_TYPE_TO_CLASS: | ||
| tracker_init = LOGGER_TYPE_TO_CLASS[str(custom_name)] | ||
| if tracker_init.requires_logging_directory and logging_dir is None: | ||
| raise ValueError(f"Logging with `{custom_name}` requires a `logging_dir` to be passed in.") | ||
| if custom_name not in loggers: | ||
| loggers.append(custom_name) |
There was a problem hiding this comment.
well here, log_with is either "all" or "LoggerType.ALL" no ?
There was a problem hiding this comment.
log_with is the full list here, so it can be e.g. ["my_tracker", "all"] the loop picks up a tracker registered by name when it's listed alongside "all", since "all" on its own only pulls in the built-ins. Happy to drop it and keep just instance + "all" if you'd prefer the smaller surface.
There was a problem hiding this comment.
yeah, i think it is better to drop it. it's a bit counter intuitive to pass ["my_tracker", "all"]
There was a problem hiding this comment.
Done in 155194d: "all" is back to main's behavior, so a registered name passed alongside "all" is no longer picked up. Selecting by name still works on its own via log_with="my_tracker".
| assert data == truth | ||
|
|
||
|
|
||
| class MyRegisteredTracker(GeneralTracker): |
There was a problem hiding this comment.
Will trim these to the shared fixture plus the core cases, together with the filter_trackers decision above (some depend on it).
1579652 to
893e814
Compare
Add register_tracker_class to register a custom GeneralTracker subclass into the LOGGER_TYPE_TO_CLASS registry so it can be selected by string name in Accelerator(log_with=...), the same way as the built-in trackers. Expose a module-level register_tracker_class in accelerate.tracking. Validate that the class subclasses GeneralTracker and defines name and requires_logging_directory, and warn when an existing name is overwritten. Extend filter_trackers so registered custom names resolve through the registry, including when combined with 'all', with the same logging-directory guard as the built-in trackers. Add tests and documentation for registering and using a custom tracker by name. Fixes huggingface#2734
893e814 to
3b9c71b
Compare
Custom trackers registered via register_tracker_class are selected by listing their name explicitly, not when combined with "all". The "all" branch of filter_trackers returns to its original behavior (instances plus available built-ins). Removes the test covering the dropped behavior and its now-unused filter_trackers import.
register_tracker_class is documented to run before Accelerator()/PartialState(), so the accelerate logger (a MultiProcessAdapter) raises RuntimeError when the name-collision warning fires (shadowing a built-in or re-registering a name). Use warnings.warn so the warning is emitted regardless of state, matching how tracker warnings are reported elsewhere. The test now exercises the real warning instead of mocking the logger.
|
Extra commit d22cb63: the overwrite warning used logger.warning, but accelerate's logger raises RuntimeError before PartialState/Accelerator init, so shadowing or re-registering a name in the documented "register first" flow would crash instead of warn. Switched that call to warnings.warn (matches accelerator.py) and the test now asserts via assertWarns. |
What does this PR do?
Adds
register_tracker_class, a small public API to register a customGeneralTrackersubclass into the internalLOGGER_TYPE_TO_CLASSregistry, so it can be selected by string name inAccelerator(log_with=...), exactly like the built-in trackers.This removes the boilerplate described in #2734, where users currently have to special-case custom trackers when building
log_with:With this change:
The API matches what was agreed in the issue thread (
nametaken from the class attribute, per @muellerzr's suggestion). It is exposed as a module-levelregister_tracker_classinaccelerate.tracking.Details
GeneralTrackerand definesnameandrequires_logging_directory.filter_trackersso registered names resolve through the registry, including when combined with"all", with the same logging-directory guard applied to the built-in trackers.requires_logging_directorypath, overwrite warnings, validation errors, and the"all"combination, plus usage-guide and API-reference docs.Fixes #2734
Before submitting
Who can review?
@SunMarc @BenjaminBossan