Skip to content

Fix tied embeddings#2237

Open
kunal-vaishnavi wants to merge 5 commits into
mainfrom
kvaishnavi/fix-tied-embeds
Open

Fix tied embeddings#2237
kunal-vaishnavi wants to merge 5 commits into
mainfrom
kvaishnavi/fix-tied-embeds

Conversation

@kunal-vaishnavi

Copy link
Copy Markdown
Contributor

Description

This PR rewrites how tied embeddings are determined in the model builder. It also adds a unit test file to ensure that the different cases for tied embeddings are tested in the CIs.

Motivation and Context

There are still some small edge cases that haven't been validated for tied embeddings.

Copilot AI review requested due to automatic review settings June 23, 2026 01:10
@kunal-vaishnavi kunal-vaishnavi requested a review from a team as a code owner June 23, 2026 01:10

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the Python model builder’s tied-embedding initialization logic and adds CI coverage to exercise tied-embedding edge cases via pytest.

Changes:

  • Refactors tied-embedding determination in the model builder to compute explicit tied_quantized_embeddings / tied_unquantized_embeddings flags and uses them in embedding construction.
  • Adds a new pytest module to cover shared-embedding configuration permutations (plus a small unit test for make_matmul_int4 behavior).
  • Updates the Python test runner entrypoint to run the builder/ and models/ pytest suites as part of the standard pipeline run.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
src/python/py/models/builders/base.py Reworks tied-embedding / quantization eligibility flags and switches embedding logic to use the new tied flags.
test/python/builder/test_tied_embeddings.py Adds unit tests for tied-embedding flag behavior and INT4 MatMul fallback/emit behavior.
test/python/test_onnxruntime_genai.py Expands the invoked pytest targets to include builder/ and models/ test suites.

Comment on lines +543 to +548
if shared_embeddings:
self.tied_quantized_embeddings = self.quantized_embeds and self.quantized_lm_head
self.tied_unquantized_embeddings = not self.tied_quantized_embeddings
else:
self.tied_unquantized_embeddings = False
self.tied_quantized_embeddings = False
Comment on lines +127 to +130
(True, True, True, False),
(True, False, False, True),
(False, True, False, True),
(False, False, False, True),
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

OGA export generating an invalid model when a lm_head unquantized model is exported

2 participants