Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clip_benchmark/metrics/image_caption_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):

dict of accuracy metrics
"""
-autocast = torch.cuda.amp.autocast if amp else suppress
+autocast = torch.amp.autocast('cuda') if amp else suppress()
image_score = []
text_score = []
score = []
Expand All @@ -52,7 +52,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):
# tokenize all texts in the batch
batch_texts_tok_ = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
# compute the embedding of images and texts
-with torch.no_grad(), autocast():
+with torch.no_grad(), autocast:
batch_images_emb = F.normalize(model.encode_image(batch_images_), dim=-1).view(B, nim, -1)
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok_), dim=-1).view(B, nt, -1)
gt = torch.arange(min(nim, nt)).to(device)
Expand Down
8 changes: 4 additions & 4 deletions clip_benchmark/metrics/linear_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
scheduler(step)

optimizer.zero_grad()
-with autocast():
+with autocast:
pred = model(x)
loss = criterion(pred, y)

Expand Down Expand Up @@ -114,7 +114,7 @@ def infer(model, dataloader, autocast, device):
x = x.to(device)
y = y.to(device)

-with autocast():
+with autocast:
logits = model(x)

pred.append(logits.cpu())
Expand Down Expand Up @@ -150,7 +150,7 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
os.mkdir(feature_dir)

featurizer = Featurizer(model, normalize).cuda()
-autocast = torch.cuda.amp.autocast if amp else suppress
+autocast = torch.amp.autocast('cuda') if amp else suppress()
if not os.path.exists(os.path.join(feature_dir, 'targets_train.pt')):
# now we have to cache the features
devices = [x for x in range(torch.cuda.device_count())]
Expand All @@ -168,7 +168,7 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
for images, target in tqdm(loader):
images = images.to(device)

-with autocast():
+with autocast:
feature = featurizer(images)

features.append(feature.cpu())
Expand Down
9 changes: 5 additions & 4 deletions clip_benchmark/metrics/zeroshot_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.metrics import classification_report, balanced_accuracy_score



def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=True):
"""
This function returns zero-shot vectors for each class in order
Expand All @@ -36,8 +37,8 @@ def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=Tr
torch.Tensor of shape (N,C) where N is the number
of templates, and C is the number of classes.
"""
-autocast = torch.cuda.amp.autocast if amp else suppress
-with torch.no_grad(), autocast():
+autocast = torch.amp.autocast('cuda') if amp else suppress()
+with torch.no_grad(), autocast:
zeroshot_weights = []
for classname in tqdm(classnames):
if type(templates) == dict:
Expand Down Expand Up @@ -100,7 +101,7 @@ def run_classification(model, classifier, dataloader, device, amp=True):
- pred (N, C) are the logits
- true (N,) are the actual classes
"""
-autocast = torch.cuda.amp.autocast if amp else suppress
+autocast = torch.amp.autocast('cuda') if amp else suppress()
pred = []
true = []
nb = 0
Expand All @@ -109,7 +110,7 @@ def run_classification(model, classifier, dataloader, device, amp=True):
images = images.to(device)
target = target.to(device)

-with autocast():
+with autocast:
# predict
image_features = model.encode_image(images)
image_features = F.normalize(image_features, dim=-1)
Expand Down
4 changes: 2 additions & 2 deletions clip_benchmark/metrics/zeroshot_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
# for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
texts_image_index = []
dataloader = dataloader_with_indices(dataloader)
-autocast = torch.cuda.amp.autocast if amp else suppress
+autocast = torch.amp.autocast('cuda') if amp else suppress()
for batch_images, batch_texts, inds in tqdm(dataloader):
batch_images = batch_images.to(device)
# tokenize all texts in the batch
Expand All @@ -49,7 +49,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]

# compute the embedding of images and texts
-with torch.no_grad(), autocast():
+with torch.no_grad(), autocast:
batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1)
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1)

Expand Down