Skip to content
4 changes: 2 additions & 2 deletions backends/arm/test/models/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
TosaPipelineBI,
TosaPipelineMI,
)

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import (
build_args_parser,
get_llama_model,
)

from executorch.extension.llm.export.config.llm_config import LlmConfig

input_t = Tuple[torch.Tensor]

# Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
Expand Down
4 changes: 2 additions & 2 deletions examples/apple/mps/scripts/mps_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from executorch.devtools.bundled_program.serialize import (
serialize_from_bundled_program_to_flatbuffer,
)

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
Expand All @@ -31,6 +29,8 @@
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge, save_pte_program

from executorch.extension.llm.export.config.llm_config import LlmConfig

from ....models import MODEL_NAME_TO_MODEL
from ....models.model_factory import EagerModelFactory

Expand Down
20 changes: 7 additions & 13 deletions examples/models/llama/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ runtime.python_library(
name = "llama_transformer",
srcs = [
"llama_transformer.py",
"rope.py",
"attention.py",
"model_args.py",
"norm.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama",
Expand All @@ -26,23 +22,21 @@ runtime.python_library(
],
deps = [
"//caffe2:torch",
"//executorch/extension/llm/modeling/text_decoder:text_decoder_attention",
"//executorch/extension/llm/modeling/text_decoder:text_decoder_model_args",
"//executorch/extension/llm/modeling/text_decoder:text_decoder_norm",
"//executorch/extension/llm/modeling/text_decoder:text_decoder_rope",
],
)

runtime.python_library(
name = "static_attention",
srcs = [
"static_attention.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama",
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
":llama_transformer",
"//caffe2:torch",
"//executorch/extension/llm/modeling/text_decoder:text_decoder_static_attention",
],
)

Expand All @@ -67,7 +61,7 @@ runtime.python_library(
"//caffe2:torch",
"//executorch/examples/models:model_base",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama/config:llm_config",
"//executorch/extension/llm/export/config:llm_config",
"//executorch/examples/models:checkpoint",
],
)
Expand Down Expand Up @@ -150,7 +144,7 @@ runtime.python_library(
":source_transformation",
"//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
"//caffe2:torch",
"//executorch/examples/models/llama/config:llm_config",
"//executorch/extension/llm/export/config:llm_config",
"//executorch/backends/vulkan/_passes:vulkan_passes",
"//executorch/exir/passes:init_mutable_pass",
"//executorch/examples/models:model_base",
Expand Down
Loading
Loading