| | import transformers |
| | from transformers import AutoTokenizer, AutoModelForMaskedLM |
| | import logging |
| | import torch |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | import numpy as np |
| | import gradio as gr |
| | from io import BytesIO |
| | from PIL import Image |
| |
|
# Hugging Face checkpoint whose masked-LM logits are visualised.
MODEL_NAME = "ChatterjeeLab/FusOn-pLM"
# Vocabulary ids kept as heatmap rows (same range the original code used).
# NOTE(review): presumably ids 4..23 are the 20 standard amino-acid tokens of an
# ESM-2-style vocabulary — confirm against the tokenizer's vocab.
AMINO_ACID_TOKEN_IDS = list(range(4, 23 + 1))
# Spacing, in residues, between labelled ticks on the x axis.
RESIDUE_TICK_STEP = 50
# Tokenizer truncation limit (counted in tokens, special tokens included).
MAX_TOKEN_LENGTH = 2000

# (tokenizer, model, device) per model name, so weights load once per process
# instead of once per request.
_MODEL_CACHE = {}


def _load_model(model_name=MODEL_NAME):
    """Return ``(tokenizer, model, device)`` for *model_name*, loading it on first use."""
    cached = _MODEL_CACHE.get(model_name)
    if cached is not None:
        return cached

    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
    model.to(device)
    model.eval()

    _MODEL_CACHE[model_name] = (tokenizer, model, device)
    return _MODEL_CACHE[model_name]


def _masked_logits(sequence, tokenizer, model, device, token_ids):
    """Mask each residue in turn and collect the logits predicted at that position.

    Returns a NumPy array of shape ``(len(sequence), len(token_ids))``: row ``i``
    holds the logits for *token_ids* when residue ``i`` is replaced by the mask.

    Raises:
        gr.Error: if a masked input does not contain exactly one mask token
            (e.g. truncation to ``MAX_TOKEN_LENGTH`` removed it, or the input
            already contained the mask text). Without this check the rows
            would silently misalign with residue indices.
    """
    rows = []
    for i in range(len(sequence)):
        masked_seq = sequence[:i] + tokenizer.mask_token + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_TOKEN_LENGTH,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
        if mask_positions.numel() != 1:
            raise gr.Error(
                f"Could not score residue {i}: found {mask_positions.numel()} mask "
                f"tokens after tokenization (sequence may exceed {MAX_TOKEN_LENGTH} tokens)."
            )
        vocab_logits = logits[0, mask_positions.item()].cpu().numpy()
        rows.append(vocab_logits[token_ids])
    return np.stack(rows)


def get_heatmap(sequence):
    """Render a per-residue heatmap of masked-LM logits for a protein sequence.

    Each residue is masked in turn. The model's logits for the vocabulary ids in
    ``AMINO_ACID_TOKEN_IDS`` at the masked position form one heatmap column.
    All values are min-max normalised to [0, 1] jointly before plotting.

    Args:
        sequence: protein sequence text; all whitespace is removed first so
            pasted multi-line sequences work.

    Returns:
        A ``PIL.Image.Image`` containing the PNG-rendered heatmap.

    Raises:
        gr.Error: if the sequence is empty or a residue cannot be scored.
    """
    sequence = "".join((sequence or "").split())
    if not sequence:
        raise gr.Error("Please enter a non-empty protein sequence.")

    tokenizer, model, device = _load_model()
    logits = _masked_logits(sequence, tokenizer, model, device, AMINO_ACID_TOKEN_IDS)
    token_labels = [tokenizer.decode([idx]) for idx in AMINO_ACID_TOKEN_IDS]

    # Joint min-max normalisation; guard the degenerate all-equal case.
    lo, hi = logits.min(), logits.max()
    span = hi - lo
    normalized = (logits - lo) / span if span > 0 else np.zeros_like(logits)

    # Use an explicit Figure rather than pyplot's global "current figure":
    # Gradio may run handlers concurrently in worker threads.
    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        sns.heatmap(
            normalized.T,
            cmap="plasma",
            xticklabels=False,
            yticklabels=token_labels,
            ax=ax,
        )
        ax.set_title("Logits for masked per residue tokens")
        ax.set_ylabel("Token")
        ax.set_xlabel("Residue Index")
        ax.tick_params(axis="y", labelrotation=0)

        # Heatmap cell i spans [i, i + 1]; centre each label on its cell.
        x_tick_positions = np.arange(0, len(sequence), RESIDUE_TICK_STEP)
        ax.set_xticks(x_tick_positions + 0.5)
        ax.set_xticklabels([str(pos) for pos in x_tick_positions], rotation=0)

        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
# Gradio UI: a text box for the protein sequence and an image for the heatmap.
# `demo` stays at module level so tools that import the module can find it.
demo = gr.Interface(fn=get_heatmap, inputs="text", outputs="image")

# Start the server only when run as a script, not as a side effect of import.
if __name__ == "__main__":
    demo.launch()