-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
371 lines (306 loc) · 11.7 KB
/
Copy pathdata.py
File metadata and controls
371 lines (306 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
"""Data utilities and PyTorch Lightning DataModule for SMILES datasets.
This module provides helpers to load SMILES files, a simple
`AtomLevelTokenizer`, dataset and datamodule implementations used by the
training and evaluation scripts.
"""
import functools
import json
import os
import re
from typing import Callable, Iterable, List, LiteralString, Optional, Self

import torch
import torch.nn as nn
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
# Atom-level SMILES tokenization pattern. Bracket atoms such as "[C@@H]" stay
# whole, "Br"/"Cl" are tried before single letters, and two-digit ring-closure
# labels "%NN" form a single token. re.findall skips any character the pattern
# does not match, so unmatched characters vanish from the token stream.
# NOTE: this is a raw string, so the alternative "\\" below matches ONE
# backslash -- the SMILES directional (cis/trans) bond symbol. Writing "\\\\"
# here would match two consecutive backslashes and silently drop every "\".
RE_PATTERN = re.compile(r"(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])")
# Special token strings, listed first (in this order) in the tokenizer vocabulary.
BOS, EOS, PAD, UNK, MSK = '<bos>', '<eos>', '<pad>', '<unk>', '<msk>'
def load_smiles(file_path: str, data_dir: str = "", max_samples: Optional[int] = None) -> List[str]:
    """Load SMILES strings from a text file, one per line.

    Parameters
    ----------
    file_path : str
        Path to the file relative to ``data_dir`` (e.g. ``chemblv31/train.txt``),
        or a complete path when ``data_dir`` is empty.
    data_dir : str, optional
        Root directory joined in front of ``file_path`` (default: "").
    max_samples : int or None, optional
        Maximum number of SMILES to return. ``None`` reads the whole file;
        ``0`` or a negative value returns an empty list (default: None).

    Returns
    -------
    List[str]
        Whitespace-stripped SMILES in file order, with blank lines skipped.
    """
    # An explicit limit of 0 means "no samples", not "no limit".
    if max_samples is not None and max_samples <= 0:
        return []
    full_path = os.path.join(data_dir, file_path) if data_dir else file_path
    smiles_list: List[str] = []
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(full_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            smiles_list.append(line)
            if max_samples is not None and len(smiles_list) >= max_samples:
                break
    return smiles_list
class AtomLevelTokenizer:
    """Atom-level SMILES tokenizer that maps regex tokens to integer ids.

    The vocabulary is the five special tokens ``BOS``, ``EOS``, ``PAD``,
    ``UNK`` and ``MSK`` (ids 0-4, in that order) followed by the sorted data
    tokens. Build one from an explicit token collection, from a SMILES corpus
    via :meth:`from_data`, or from a saved JSON vocabulary via :meth:`load`.
    """

    def __init__(self, tokens: Iterable[str] = ()) -> None:
        """Initialize the tokenizer from a collection of data tokens.

        Parameters
        ----------
        tokens : Iterable[str], optional
            Token strings such as ``"C"``, ``"[nH]"`` or ``"("``. Duplicates
            and special-token strings are ignored. Defaults to no data tokens,
            which gives a vocabulary of only the special tokens.
        """
        specials = [BOS, EOS, PAD, UNK, MSK]
        # De-duplicate (and drop specials) so that every token owns exactly one
        # id. Otherwise c2i/i2c disagree and vocab_size undercounts the max id.
        data_tokens = sorted(set(tokens) - set(specials))
        all_tokens = specials + data_tokens
        self.c2i = {c: i for i, c in enumerate(all_tokens)}
        self.i2c = {i: c for i, c in enumerate(all_tokens)}
        self.vocab_size = len(self.c2i)
        self.BOS = BOS
        self.EOS = EOS
        self.PAD = PAD
        self.UNK = UNK
        self.MSK = MSK

    @property
    def bos(self) -> int:
        """Id of the beginning-of-sequence token."""
        return self.c2i[BOS]

    @property
    def eos(self) -> int:
        """Id of the end-of-sequence token."""
        return self.c2i[EOS]

    @property
    def pad(self) -> int:
        """Id of the padding token."""
        return self.c2i[PAD]

    @property
    def unk(self) -> int:
        """Id used for tokens missing from the vocabulary."""
        return self.c2i[UNK]

    @property
    def mask(self) -> int:
        """Id of the mask token."""
        return self.c2i[MSK]

    def string2tensor(self, string: str, add_bos: bool = True, add_eos: bool = True) -> torch.Tensor:
        """Convert a SMILES string to a 1D ``torch.long`` tensor of token ids.

        Parameters
        ----------
        string : str
            SMILES string to tokenize. Surrounding whitespace is stripped.
        add_bos : bool, optional
            Prepend the BOS id (default: True).
        add_eos : bool, optional
            Append the EOS id (default: True).

        Returns
        -------
        torch.Tensor
            Token ids. Tokens not in the vocabulary map to the UNK id.
        """
        ids = [self.c2i.get(t, self.unk) for t in RE_PATTERN.findall(string.strip())]
        if add_bos:
            ids = [self.bos] + ids
        if add_eos:
            ids = ids + [self.eos]
        return torch.tensor(ids, dtype=torch.long)

    def tensor2string(self, tensor: torch.Tensor, rem_bos: bool = True, rem_eos: bool = True) -> str:
        """Convert a 1D tensor of token ids back into a SMILES string.

        Parameters
        ----------
        tensor : torch.Tensor
            1D tensor of token ids.
        rem_bos : bool, optional
            Drop a leading BOS id if present (default: True).
        rem_eos : bool, optional
            Truncate at the first EOS id if present (default: True).

        Returns
        -------
        str
            Concatenated token strings. PAD and MSK ids are skipped, and ids
            unknown to the vocabulary are rendered as the literal ``UNK`` text.
        """
        ids = tensor.tolist()
        if rem_bos and ids and ids[0] == self.bos:
            ids = ids[1:]
        if rem_eos and self.eos in ids:
            ids = ids[:ids.index(self.eos)]
        skip = (self.pad, self.mask)
        return ''.join(self.i2c.get(i, UNK) for i in ids if i not in skip)

    def save(self, filepath: str) -> None:
        """Write the token -> id mapping to ``filepath`` as UTF-8 JSON.

        Parameters
        ----------
        filepath : str
            Output JSON file path.
        """
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.c2i, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, filepath: str) -> Self:
        """Rebuild a tokenizer from a JSON vocabulary written by :meth:`save`.

        Special tokens in the file are skipped (the constructor re-adds them).
        The remaining tokens are re-sorted by the constructor, so ids match
        the saved file when it was produced by this class.

        Parameters
        ----------
        filepath : str
            Path to the JSON vocabulary file.

        Returns
        -------
        AtomLevelTokenizer
            Tokenizer reconstructed from the saved vocabulary.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls([k for k in sorted(data, key=data.get) if k not in (BOS, EOS, PAD, UNK, MSK)])

    @classmethod
    def from_data(cls, data: Iterable[str]) -> Self:
        """Build a tokenizer from every regex token observed in ``data``.

        Parameters
        ----------
        data : Iterable[str]
            SMILES strings to scan.

        Returns
        -------
        AtomLevelTokenizer
            Tokenizer whose vocabulary covers all tokens found in ``data``.
        """
        tokens = set()
        for smiles in data:
            tokens.update(RE_PATTERN.findall(smiles.strip()))
        return cls(tokens)
def _collate_batch(batch, tokenizer):
    """Pad a batch of ``(smiles, label)`` pairs into tensors using ``tokenizer``.

    Parameters
    ----------
    batch : Sequence[Tuple[str, Optional[int]]]
        Batch of ``(smiles, label)`` tuples as produced by ``SMILESDataset``.
    tokenizer
        Object exposing ``string2tensor(str) -> 1D LongTensor`` and a ``pad`` id.

    Returns
    -------
    dict
        ``x``: ``(batch, max_len)`` token ids right-padded with ``tokenizer.pad``.
        ``m``: boolean tensor of the same shape, True at real (non-pad) positions.
        ``labels``: ``torch.long`` tensor, present only when the first item's
        label is not None.
    """
    smiles_list, labels_list = zip(*batch)
    tensors = [tokenizer.string2tensor(s) for s in smiles_list]
    lengths = torch.tensor([len(t) for t in tensors])
    x = nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=tokenizer.pad)
    # Position j of row i is real iff j < length of sequence i.
    m = torch.arange(x.size(1)) < lengths.unsqueeze(1)
    result = {'x': x, 'm': m}
    if labels_list[0] is not None:
        result['labels'] = torch.tensor(labels_list, dtype=torch.long)
    return result


def collate_fn(tokenizer):
    """Return a DataLoader collate function bound to ``tokenizer``.

    The result is a ``functools.partial`` over the module-level
    ``_collate_batch`` rather than a nested closure, so it can be pickled.
    DataLoader must pickle it for worker processes started with the
    ``spawn`` method (``num_workers > 0``).
    """
    return functools.partial(_collate_batch, tokenizer=tokenizer)
class SMILESDataset(Dataset):
    """Map-style dataset that yields ``(smiles, label)`` pairs.

    Parameters
    ----------
    smiles : List[str]
        SMILES strings, one per sample.
    labels : Optional[Sequence], optional
        Per-sample labels indexed in parallel with ``smiles``. When omitted,
        every item comes back with ``None`` as its label (default: None).
    """

    def __init__(self, smiles: List[str], labels=None):
        self.smiles = smiles
        self.labels = labels

    def __len__(self):
        # Dataset size is defined by the SMILES list, not by the labels.
        return len(self.smiles)

    def __getitem__(self, idx):
        if self.labels is None:
            target = None
        else:
            target = self.labels[idx]
        sample = self.smiles[idx]
        return sample, target
class SMILESDataModule(LightningDataModule):
    """PyTorch Lightning DataModule for SMILES pretraining datasets.

    Expects ``<data_dir>/<dataset>/`` to contain ``train.txt``, ``valid.txt``
    and ``test.txt`` (one SMILES per line). All three splits are loaded
    eagerly in ``__init__``.

    Parameters
    ----------
    dataset : str
        Dataset folder name under ``data_dir`` (default: "chemblv31").
    data_dir : str
        Root data directory (default: "").
    batch_size : int
        Batch size used by every DataLoader (default: 64).
    num_workers : int
        Number of DataLoader worker processes (default: 16).
    pin_memory : bool
        Forwarded to DataLoader (default: False).
    max_samples : int or None
        Per-split cap on the number of SMILES loaded (default: None = all).
    collate_fn : callable or None
        How this is interpreted depends on ``tokenizer``:
        - If ``tokenizer`` is given, this is a factory ``tokenizer -> collate``.
        - If ``tokenizer`` is None, this is a ready batch-level collate function
          passed to DataLoader unchanged.
        - ``None`` (default) selects this module's ``collate_fn`` factory.
    split_ratio : float
        Stored on the instance; not used by this class (default: 0.8).
    tokenizer : AtomLevelTokenizer or None
        Tokenizer to bind into the collate factory. If omitted while the
        default factory is in use, ``<data_dir>/<dataset>/vocab.json`` is loaded
        when it exists. Otherwise a tokenizer is built from the loaded splits.
    **kwargs
        Accepted and ignored.
    """

    # The module-level ``collate_fn`` factory, captured at class-definition
    # time because the ``collate_fn`` parameter of ``__init__`` shadows it.
    _default_collate_factory = staticmethod(collate_fn)

    def __init__(
        self,
        dataset: str = "chemblv31",
        data_dir: str = "",
        batch_size: int = 64,
        num_workers: int = 16,
        pin_memory: bool = False,
        max_samples: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        split_ratio: float = 0.8,
        tokenizer: Optional["AtomLevelTokenizer"] = None,
        **kwargs
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.data_dir = data_dir
        self.batch_size_per_device = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.max_samples = max_samples
        self.split_ratio = split_ratio
        self.data_trn = self.data_val = self.data_tst = None
        self._prepare_and_setup()

        if collate_fn is None:
            collate_fn = self._default_collate_factory
        if tokenizer is None and collate_fn is self._default_collate_factory:
            # The default is a factory and must be bound to a tokenizer before
            # DataLoader can call it on a batch.
            tokenizer = self._default_tokenizer()
        self.tokenizer = tokenizer
        # With a tokenizer, collate_fn is a factory (tokenizer -> collate).
        # Without one, it is taken to be a ready batch-level collate function.
        self.collate_fn = collate_fn(tokenizer) if tokenizer is not None else collate_fn

    def _prepare_and_setup(self):
        """Load the three split files into train/val/test ``SMILESDataset`` objects."""
        trn_smiles = load_smiles(f'{self.dataset}/train.txt', self.data_dir, self.max_samples)
        val_smiles = load_smiles(f'{self.dataset}/valid.txt', self.data_dir, self.max_samples)
        tst_smiles = load_smiles(f'{self.dataset}/test.txt', self.data_dir, self.max_samples)
        self.data_trn = SMILESDataset(trn_smiles)
        self.data_val = SMILESDataset(val_smiles)
        self.data_tst = SMILESDataset(tst_smiles)

    def _default_tokenizer(self) -> "AtomLevelTokenizer":
        """Load the dataset's ``vocab.json`` if present, else build one from the loaded splits."""
        # NOTE(review): assumes vocab.json lives beside the split files
        # (e.g. as written by process_dataset with data_dir='data') -- confirm.
        vocab_path = os.path.join(self.data_dir, self.dataset, 'vocab.json')
        if os.path.exists(vocab_path):
            return AtomLevelTokenizer.load(vocab_path)
        return AtomLevelTokenizer.from_data(
            self.data_trn.smiles + self.data_val.smiles + self.data_tst.smiles
        )

    def _make_loader(self, data, shuffle: bool) -> DataLoader:
        """Build a DataLoader over ``data`` with this module's shared settings."""
        return DataLoader(
            data,
            batch_size=self.batch_size_per_device,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.data_trn, shuffle=True)

    def val_dataloader(self) -> Optional[DataLoader]:
        if self.data_val is None:
            return None
        return self._make_loader(self.data_val, shuffle=False)

    def test_dataloader(self) -> Optional[DataLoader]:
        if self.data_tst is None:
            return None
        return self._make_loader(self.data_tst, shuffle=False)
def process_dataset(dataset_name: str, data_root: str = 'data') -> int:
    """Build and save a dataset's vocabulary and return its max token length.

    Reads whichever of ``train.txt``, ``valid.txt`` and ``test.txt`` exist in
    ``<data_root>/<dataset_name>``. Builds an ``AtomLevelTokenizer`` from all
    of them and writes it to ``vocab.json`` in the same folder.

    Args:
        dataset_name: Name of the dataset folder (e.g. 'chemblv31', 'coconut').
        data_root: Directory that contains the dataset folders (default: 'data').

    Returns:
        Maximum sequence length after tokenization, counting BOS and EOS.

    Raises:
        FileNotFoundError: If none of the split files exist.
    """
    data_dir = os.path.join(data_root, dataset_name)
    # Individual splits are optional, but at least one must exist. Otherwise
    # the vocabulary would silently contain only the special tokens.
    present = [
        fname for fname in ('train.txt', 'valid.txt', 'test.txt')
        if os.path.exists(os.path.join(data_dir, fname))
    ]
    if not present:
        raise FileNotFoundError(f'No train.txt/valid.txt/test.txt found in {data_dir}')

    smiles = []
    for fname in present:
        smiles.extend(load_smiles(fname, data_dir))

    tokenizer = AtomLevelTokenizer.from_data(smiles)
    out_path = os.path.join(data_dir, 'vocab.json')
    tokenizer.save(out_path)
    print(f'Saved vocabulary to {out_path}')

    return max(
        (len(tokenizer.string2tensor(s, add_bos=True, add_eos=True))
         for s in tqdm(smiles, desc=f'Tokenizing {dataset_name}')),
        default=0,
    )
def main() -> None:
    """Build the vocabulary for each configured dataset and print its max length."""
    for name in ('guacamol', 'chemblv31'):
        longest = process_dataset(name)
        print(f'{name}: max sequence length = {longest}')
        print()


# Recorded note (presumably an earlier run's output): chemblv31: max sequence length = 100
if __name__ == '__main__':
    main()