Skip to content

Commit 94cf62f

Browse files
authored
Merge pull request #1217 from PyThaiNLP/copilot/add-qwen3-0-6b-model
Add Qwen3-0.6B language model support to pythainlp.lm and improve type annotations
2 parents 05b0ceb + f529a9b commit 94cf62f

File tree

8 files changed

+365
-13
lines changed

8 files changed

+365
-13
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ onnx = ["numpy>=1.22", "onnxruntime>=1.10.0", "sentencepiece>=0.1.91"]
124124

125125
oskut = ["oskut>=1.3"]
126126

127+
qwen3 = ["torch>=1.9.0", "transformers>=4.22.1"]
128+
127129
sefr_cut = ["sefr_cut>=1.1"]
128130

129131
spacy_thai = ["spacy_thai>=0.7.1"]

pythainlp/chat/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from pythainlp.generate.wangchanglm import WangChanGLM
1212

13-
1413
class ChatBotModel:
1514
history: list[tuple[str, str]]
1615
model: "WangChanGLM"
@@ -39,7 +38,7 @@ def load_model(
3938
:param bool return_dict: return_dict
4039
:param bool load_in_8bit: load model in 8bit
4140
:param str device: device (cpu, cuda or other)
42-
:param torch_dtype torch_dtype: torch_dtype
41+
:param Optional[torch.dtype] torch_dtype: torch_dtype
4342
:param str offload_folder: offload folder
4443
:param bool low_cpu_mem_usage: low cpu mem usage
4544
"""

pythainlp/generate/wangchanglm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def load_model(
5454
:param bool return_dict: return dict
5555
:param bool load_in_8bit: load model in 8bit
5656
:param str device: device (cpu, cuda or other)
57-
:param torch_dtype torch_dtype: torch_dtype
57+
:param Optional[torch.dtype] torch_dtype: torch_dtype
5858
:param str offload_folder: offload folder
5959
:param bool low_cpu_mem_usage: low cpu mem usage
6060
"""

pythainlp/lm/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
# SPDX-FileType: SOURCE
33
# SPDX-License-Identifier: Apache-2.0
44

5-
__all__: list[str] = ["calculate_ngram_counts", "remove_repeated_ngrams"]
5+
__all__: list[str] = ["calculate_ngram_counts", "remove_repeated_ngrams", "Qwen3"]
66

7+
from pythainlp.lm.qwen3 import Qwen3
78
from pythainlp.lm.text_util import (
89
calculate_ngram_counts,
910
remove_repeated_ngrams,

pythainlp/lm/qwen3.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project
# SPDX-FileType: SOURCE
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    import torch
    from transformers import PreTrainedModel, PreTrainedTokenizerBase


class Qwen3:
    """Qwen3-0.6B language model for Thai text generation.

    A small but capable language model from Alibaba Cloud's Qwen family,
    optimized for various NLP tasks including Thai language processing.
    """

    def __init__(self) -> None:
        # All state is populated by load_model(); None means "not loaded".
        self.model: Optional["PreTrainedModel"] = None
        self.tokenizer: Optional["PreTrainedTokenizerBase"] = None
        self.device: Optional[str] = None
        self.torch_dtype: Optional["torch.dtype"] = None
        self.model_path: Optional[str] = None

    def load_model(
        self,
        model_path: str = "Qwen/Qwen3-0.6B",
        device: str = "cuda",
        torch_dtype: Optional["torch.dtype"] = None,
        low_cpu_mem_usage: bool = True,
    ) -> None:
        """Load Qwen3 model.

        :param str model_path: model path or HuggingFace model ID
        :param str device: device (cpu, cuda or other)
        :param Optional[torch.dtype] torch_dtype: torch data type (e.g., torch.float16, torch.bfloat16)
        :param bool low_cpu_mem_usage: low cpu mem usage

        :Example:
        ::

            from pythainlp.lm import Qwen3
            import torch

            model = Qwen3()
            model.load_model(device="cpu", torch_dtype=torch.bfloat16)
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except (ImportError, ModuleNotFoundError) as exc:
            raise ImportError(
                "Qwen3 language model requires optional dependencies. "
                "Install them with: pip install 'pythainlp[qwen3]'"
            ) from exc

        # Set default torch_dtype if not provided
        if torch_dtype is None:
            torch_dtype = torch.float16

        # Check CUDA availability early, before any expensive download/load
        if device.startswith("cuda") and not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA device requested but CUDA is not available. "
                "Check your PyTorch installation and GPU drivers, or use "
                "device='cpu' instead."
            )

        self.device = device
        self.torch_dtype = torch_dtype
        self.model_path = model_path

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        except OSError as exc:
            raise RuntimeError(
                f"Failed to load tokenizer from '{self.model_path}'. "
                "Check the model path or your network connection."
            ) from exc

        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                device_map=device,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        except OSError as exc:
            # Clean up tokenizer on failure
            self.tokenizer = None
            raise RuntimeError(
                f"Failed to load model from '{self.model_path}'. "
                "This can happen due to an invalid model path, missing files, "
                "or insufficient disk space."
            ) from exc
        except Exception as exc:
            # Clean up tokenizer on failure
            self.tokenizer = None
            raise RuntimeError(
                f"Failed to load model weights: {exc}. "
                "This can be caused by insufficient memory, an incompatible "
                "torch_dtype setting, or other configuration issues."
            ) from exc

    def _check_loaded(self) -> None:
        """Raise RuntimeError unless load_model() has completed successfully."""
        if self.model is None or self.tokenizer is None or self.device is None:
            raise RuntimeError(
                "Model not loaded. Please call load_model() first."
            )

    def _run_generation(
        self,
        prompt: str,
        *,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        do_sample: bool,
        skip_special_tokens: bool,
    ) -> str:
        """Tokenize *prompt*, run the model, and decode only the new tokens.

        Shared implementation behind generate() and chat(); callers must
        invoke _check_loaded() first.
        """
        try:
            import torch
        except (ImportError, ModuleNotFoundError) as exc:
            raise ImportError(
                "Qwen3 language model requires optional dependencies. "
                "Install them with: pip install 'pythainlp[qwen3]'"
            ) from exc

        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        # Pass the attention mask explicitly instead of letting transformers
        # infer it from pad tokens (silences a warning; identical result for
        # this unpadded, batch-size-1 input).
        attention_mask = inputs["attention_mask"].to(self.device)

        # Note: When do_sample=False (greedy decoding), temperature, top_p,
        # and top_k parameters are ignored by the transformers library
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
            )

        # Decode only the newly generated tokens.
        # output_ids and input_ids are 2D tensors with batch size 1 from the
        # tokenizer call above, so row 0 holds the full sequence.
        return self.tokenizer.decode(
            output_ids[0][input_ids.shape[1]:],
            skip_special_tokens=skip_special_tokens,
        )

    def generate(
        self,
        text: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = True,
        skip_special_tokens: bool = True,
    ) -> str:
        """Generate text from a prompt.

        :param str text: input text prompt
        :param int max_new_tokens: maximum number of new tokens to generate
        :param float temperature: temperature for sampling (higher = more random)
        :param float top_p: top p for nucleus sampling
        :param int top_k: top k for top-k sampling
        :param bool do_sample: whether to use sampling or greedy decoding
        :param bool skip_special_tokens: skip special tokens in output
        :return: generated text
        :rtype: str

        :Example:
        ::

            from pythainlp.lm import Qwen3
            import torch

            model = Qwen3()
            model.load_model(device="cpu", torch_dtype=torch.bfloat16)

            result = model.generate("สวัสดี")
            print(result)
        """
        self._check_loaded()

        if not text or not isinstance(text, str):
            raise ValueError(
                "text parameter must be a non-empty string."
            )

        return self._run_generation(
            text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            skip_special_tokens=skip_special_tokens,
        )

    def chat(
        self,
        messages: list[dict[str, Any]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = True,
        skip_special_tokens: bool = True,
    ) -> str:
        """Generate text using chat format.

        :param list[dict[str, Any]] messages: list of message dictionaries with 'role' and 'content' keys
        :param int max_new_tokens: maximum number of new tokens to generate
        :param float temperature: temperature for sampling
        :param float top_p: top p for nucleus sampling
        :param int top_k: top k for top-k sampling
        :param bool do_sample: whether to use sampling
        :param bool skip_special_tokens: skip special tokens in output
        :return: generated response
        :rtype: str

        :Example:
        ::

            from pythainlp.lm import Qwen3
            import torch

            model = Qwen3()
            model.load_model(device="cpu", torch_dtype=torch.bfloat16)

            messages = [{"role": "user", "content": "สวัสดีครับ"}]
            response = model.chat(messages)
            print(response)
        """
        self._check_loaded()

        if not messages or not isinstance(messages, list):
            raise ValueError(
                "messages parameter must be a non-empty list of message dictionaries."
            )

        # Apply chat template if available, otherwise format manually
        if hasattr(self.tokenizer, "apply_chat_template"):
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            # Simple fallback format - preserve content newlines
            lines = []
            for msg in messages:
                role = str(msg.get("role", "user")).replace("\n", " ")
                content = str(msg.get("content", ""))
                lines.append(f"{role}: {content}")
            prompt = "\n".join(lines) + "\nassistant: "

        return self._run_generation(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            skip_special_tokens=skip_special_tokens,
        )

pythainlp/phayathaibert/core.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@
1010
from typing import TYPE_CHECKING, Union
1111

1212
if TYPE_CHECKING:
13-
from transformers import CamembertTokenizer
14-
from transformers.pipelines.base import Pipeline
13+
from transformers import (
14+
AutoModelForMaskedLM,
15+
AutoModelForTokenClassification,
16+
CamembertTokenizer,
17+
Pipeline,
18+
PreTrainedTokenizerBase,
19+
)
1520

1621
from transformers import (
1722
CamembertTokenizer,
@@ -212,13 +217,13 @@ def __init__(self) -> None:
212217
pipeline,
213218
)
214219

215-
self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
220+
self.tokenizer: "PreTrainedTokenizerBase" = AutoTokenizer.from_pretrained(
216221
_model_name
217222
)
218-
self.model_for_masked_lm: AutoModelForMaskedLM = (
223+
self.model_for_masked_lm: "AutoModelForMaskedLM" = (
219224
AutoModelForMaskedLM.from_pretrained(_model_name)
220225
)
221-
self.model: "Pipeline" = pipeline(
226+
self.model: "Pipeline" = pipeline( # transformers.Pipeline
222227
"fill-mask",
223228
tokenizer=self.tokenizer,
224229
model=self.model_for_masked_lm,
@@ -311,8 +316,8 @@ def __init__(self, model: str = "lunarlist/pos_thai_phayathai") -> None:
311316
AutoTokenizer,
312317
)
313318

314-
self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model)
315-
self.model: AutoModelForTokenClassification = (
319+
self.tokenizer: "PreTrainedTokenizerBase" = AutoTokenizer.from_pretrained(model)
320+
self.model: "AutoModelForTokenClassification" = (
316321
AutoModelForTokenClassification.from_pretrained(model)
317322
)
318323

@@ -356,8 +361,8 @@ def __init__(self, model: str = "Pavarissy/phayathaibert-thainer") -> None:
356361
AutoTokenizer,
357362
)
358363

359-
self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model)
360-
self.model: AutoModelForTokenClassification = (
364+
self.tokenizer: "PreTrainedTokenizerBase" = AutoTokenizer.from_pretrained(model)
365+
self.model: "AutoModelForTokenClassification" = (
361366
AutoModelForTokenClassification.from_pretrained(model)
362367
)
363368

tests/noauto_torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
# Names of module to be tested
2525
test_packages: list[str] = [
26+
"tests.noauto_torch.testn_lm_torch",
2627
"tests.noauto_torch.testn_spell_torch",
2728
"tests.noauto_torch.testn_tag_torch",
2829
"tests.noauto_torch.testn_tokenize_torch",

0 commit comments

Comments
 (0)