In [2]:
import torch
import pandas as pd
from torch import optim, nn
from torch.utils.data import Dataset, DataLoader
from evo2 import Evo2
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load Evo2 model
evo2_model = Evo2('evo2_7b')
tokenizer = evo2_model.tokenizer

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 2739.14it/s]


Found complete file in repo: evo2_7b.pt


100%|██████████| 32/32 [00:01<00:00, 23.08it/s]

DEBUG: About to load checkpoint from /home/mk5636/.cache/huggingface/hub/models--arcinstitute--evo2_7b/snapshots/ab55f718e990600e4f7c9992c313d97e499a20dc/evo2_7b.pt
Loading checkpoint from: /home/mk5636/.cache/huggingface/hub/models--arcinstitute--evo2_7b/snapshots/ab55f718e990600e4f7c9992c313d97e499a20dc/evo2_7b.pt





Extra keys in state_dict: {'blocks.24.mixer.attn._extra_state', 'blocks.13.mixer.mixer.filter.t', 'blocks.3.mixer.attn._extra_state', 'blocks.24.mixer.dense._extra_state', 'blocks.16.mixer.mixer.filter.t', 'blocks.17.mixer.dense._extra_state', 'blocks.10.mixer.attn._extra_state', 'blocks.10.mixer.dense._extra_state', 'blocks.27.mixer.mixer.filter.t', 'blocks.23.mixer.mixer.filter.t', 'blocks.3.mixer.dense._extra_state', 'blocks.9.mixer.mixer.filter.t', 'blocks.30.mixer.mixer.filter.t', 'blocks.2.mixer.mixer.filter.t', 'blocks.20.mixer.mixer.filter.t', 'unembed.weight', 'blocks.17.mixer.attn._extra_state', 'blocks.6.mixer.mixer.filter.t', 'blocks.31.mixer.dense._extra_state', 'blocks.31.mixer.attn._extra_state'}


  state = torch.load(state, map_location="cuda")
  return torch_load(state, map_location=device)


In [3]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Load data
train_df = pd.read_csv("train_df.csv")
test_df = pd.read_csv("test_df.csv")

# Remove 'N' (padding) from fugu sequences
train_df['seq_fugu'] = train_df['seq_fugu'].str.replace('N', '', regex=False)
test_df['seq_fugu'] = test_df['seq_fugu'].str.replace('N', '', regex=False)

In [4]:
# Toy data
toy_data = pd.DataFrame({
    "seq_human": ["ACTG", "GATTACA", "CGT"],
    "seq_fugu":  ["TGAC", "CTAATGT", "GCA"]  # complement strand
})
train_df = toy_data

In [11]:
import random

def random_dna_sequence(length=10):
    # Generate a random sequence using A, C, T, G
    return ''.join(random.choices("ACTG", k=length))

def complement(seq):
    # Define the complement mapping for DNA
    mapping = {"A": "T", "T": "A", "C": "G", "G": "C"}
    # Compute the complement for the sequence
    return ''.join(mapping[base] for base in seq)

# Generate 5000 examples
num_examples = 5000
data = {"seq_human": [], "seq_fugu": []}

for _ in range(num_examples):
    seq = random_dna_sequence(10)
    comp = complement(seq)
    data["seq_human"].append(seq)
    data["seq_fugu"].append(comp)

toy_data = pd.DataFrame(data)
train_df = toy_data

print(toy_data.head())


    seq_human    seq_fugu
0  TAGTTGGACC  ATCAACCTGG
1  CCGTACCTCT  GGCATGGAGA
2  GTCAATGTCG  CAGTTACAGC
3  TCCTCTTTCC  AGGAGAAAGG
4  CCACGCCTTC  GGTGCGGAAG


In [5]:
# Tokenization + masking function
def tokenize_and_mask(human, fugu, sep="#", eos=tokenizer.eos):
    # Format full string
    full_seq = human + sep + fugu + chr(eos)
    
    # Tokenize
    tokens = tokenizer.tokenize(full_seq)

    # Find separator position
    sep_token = ord(sep)  # '#' → 35
    sep_index = tokens.index(sep_token)

    # Build loss mask: 0 before/including '#', 1 after
    loss_mask = [0] * (sep_index + 1) + [1] * (len(tokens) - (sep_index + 1))

    return {
        "input_ids": tokens,
        "loss_mask": loss_mask
    }

# Dataset class
class Evo2Dataset(Dataset):
    def __init__(self, dataframe):
        self.df = dataframe

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        tokenized = tokenize_and_mask(row['seq_human'], row['seq_fugu'])

        return {
            "input_ids": torch.tensor(tokenized["input_ids"], dtype=torch.long),
            "loss_mask": torch.tensor(tokenized["loss_mask"], dtype=torch.float),
        }

# Collate function for padding
def collate_fn(batch):
    input_ids = [item["input_ids"] for item in batch]
    loss_masks = [item["loss_mask"] for item in batch]

    # Pad input IDs
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_id
    )

    # Pad loss masks (0 means ignore)
    loss_masks = torch.nn.utils.rnn.pad_sequence(
        loss_masks, batch_first=True, padding_value=0
    )

    return {
        "input_ids": input_ids,
        "attention_mask": (input_ids != tokenizer.pad_id).long(),
        "labels": input_ids.clone(),
        "loss_mask": loss_masks
    }

In [13]:
# Create DataLoader
train_dataset = Evo2Dataset(train_df)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=True
)

In [7]:
import yaml
from types import SimpleNamespace

# Load the YAML file
with open("/scratch/mk5636/DS-GA1010/evo2/evo2-7b-1m.yml", "r") as f:
    config = yaml.safe_load(f)

# Assuming 'config' is the dictionary you loaded from the YAML
evo2_model.model.config = SimpleNamespace(**config)

# Define a method to allow 'get' functionality for SimpleNamespace
def get_from_config(self, key, default=None):
    return getattr(self, key, default)

# Add the get method to the config object
evo2_model.model.config.get = get_from_config.__get__(evo2_model.model.config)


In [8]:
def prepare_inputs_for_generation(input_ids, **kwargs):
    return {"input_ids": input_ids}

evo2_model.model.prepare_inputs_for_generation = prepare_inputs_for_generation

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["Wqkv", "out_proj", "l1", "l2", "l3"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Create LoRA-wrapped model
from peft.tuners.lora.model import LoraModel
model = LoraModel(evo2_model.model, lora_config, adapter_name="default")


# Move model to device
model.to(device)
model.train()

# Verify it worked
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

model.blocks.0.mlp.l1.lora_A.default.weight
model.blocks.0.mlp.l1.lora_B.default.weight
model.blocks.0.mlp.l2.lora_A.default.weight
model.blocks.0.mlp.l2.lora_B.default.weight
model.blocks.0.mlp.l3.lora_A.default.weight
model.blocks.0.mlp.l3.lora_B.default.weight
model.blocks.1.mlp.l1.lora_A.default.weight
model.blocks.1.mlp.l1.lora_B.default.weight
model.blocks.1.mlp.l2.lora_A.default.weight
model.blocks.1.mlp.l2.lora_B.default.weight
model.blocks.1.mlp.l3.lora_A.default.weight
model.blocks.1.mlp.l3.lora_B.default.weight
model.blocks.2.mlp.l1.lora_A.default.weight
model.blocks.2.mlp.l1.lora_B.default.weight
model.blocks.2.mlp.l2.lora_A.default.weight
model.blocks.2.mlp.l2.lora_B.default.weight
model.blocks.2.mlp.l3.lora_A.default.weight
model.blocks.2.mlp.l3.lora_B.default.weight
model.blocks.3.inner_mha_cls.Wqkv.lora_A.default.weight
model.blocks.3.inner_mha_cls.Wqkv.lora_B.default.weight
model.blocks.3.inner_mha_cls.out_proj.lora_A.default.weight
model.blocks.3.inner_mha_cls.out_pro

In [9]:
def replace_param(model, param_name, new_param):
    names = param_name.split('.')
    mod = model
    for name in names[:-1]:
        mod = getattr(mod, name)
    setattr(mod, names[-1], new_param)

def fix_inference_tensors(model):
    # Iterate over the parameters of the base model
    for name, param in list(model.model.named_parameters()):
        # Create a new parameter from a clone of the data
        new_param = torch.nn.Parameter(param.detach().clone(), requires_grad=param.requires_grad)
        replace_param(model.model, name, new_param)

# Apply the fix to ensure base model parameters are not inference tensors
fix_inference_tensors(model)


In [14]:
# Set up optimizer
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

num_epochs = 3

for epoch in range(num_epochs):
    total_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}", dynamic_ncols=True)

    for batch in progress:
        input_ids = batch["input_ids"].to(device).clone().contiguous().to(torch.long)
        attention_mask = batch["attention_mask"].to(device).clone().contiguous()
        labels = batch["labels"].to(device)
        loss_mask = batch["loss_mask"].to(device)

        optimizer.zero_grad()

        model.train()
        model.model.stateful = False
        model.model.training = True

        # Forward pass
        with torch.enable_grad():
            outputs = model.model(
                x=input_ids,
                padding_mask=attention_mask
            )

        
        logits = outputs[0] # shape: [batch_size, seq_len, vocab_size], logits are raw predicted scores for each token

        # Manually compute masked loss
        shift_logits = logits[:, :-1, :].contiguous() # shape: [batch_size, seq_len-1, vocab_size], drop last token because it's not used for prediction (there is no next token)
        shift_labels = labels[:, 1:].contiguous() # shape: [batch_size, seq_len-1], drop first token because it's not predicted
        # now logits ane labels are alignd, ie. logits[0, 0] is the prediction for labels[0, 0] (which was labels[0, 1] in the original input)
        shift_mask = loss_mask[:, 1:].contiguous() # shape: [batch_size, seq_len-1], drop first position because it's not being predicted

        loss_fct = nn.CrossEntropyLoss(reduction="none") #reduction="none" to get per-token loss
        # per_token_loss = loss_fct(logits, labels) # this will compute loss for all tokens
        # .view(-1) rshapes/flattens the tensor to one vector, e.g. [batch_size, seq_len] -> [batch_size*seq_len]
        # .view(-1, 256) reshapes [batch_size, seq_len, 256] to [batch_size*seq_len, 256] (256 is vocab size)
        per_token_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), # shape: [batch_size*seq_len, vocab_size] e.g. [10, 256] (predicted scores for each position)
            shift_labels.view(-1) # shape: [batch_size*seq_len] e.g. [10] (true labels for each position)
        )

        # Apply loss mask
        # take the per-token loss values and multiply them by the mask so that only 
        # certain positions (after the '#' token) contribute to the total loss
        masked_loss = per_token_loss * shift_mask.view(-1) 
        loss = masked_loss.sum() / shift_mask.sum()

        # Backprop and update
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

    model.eval()
    with torch.no_grad():
        for example in toy_data["seq_human"][:5]:
            # Create input prompt
            prompt = example + "#"

            output_tokens = []
            output_text = prompt

            evo2_model.model = model

            max_new_tokens = 10  # A generous upper bound
            output = evo2_model.generate(
                prompt_seqs=[prompt],
                n_tokens=max_new_tokens,
                temperature=0.01,
                batched=True,
                cached_generation=True,
                verbose=1,
            )

            print("Generated Sequence:", output.sequences[0])


Epoch 1: 100%|██████████| 5000/5000 [07:31<00:00, 11.08it/s, loss=0.118]


Epoch 1 average loss: 0.7853
Initializing inference params with max_seqlen=21
Prompt: "TAGTTGGACC#",	Output: "CTAAACCTGG",	Score: -0.13750103116035461
Generated Sequence: CTAAACCTGG
Initializing inference params with max_seqlen=21
Prompt: "CCGTACCTCT#",	Output: "CGCATGGAGA",	Score: -0.1476009339094162
Generated Sequence: CGCATGGAGA
Initializing inference params with max_seqlen=21
Prompt: "GTCAATGTCG#",	Output: "CAGTTACAGC",	Score: -0.12518373131752014
Generated Sequence: CAGTTACAGC
Initializing inference params with max_seqlen=21
Prompt: "TCCTCTTTCC#",	Output: "GGGAGAAAGG",	Score: -0.13489963114261627
Generated Sequence: GGGAGAAAGG
Initializing inference params with max_seqlen=21
Prompt: "CCACGCCTTC#",	Output: "GGTGCGGAAG",	Score: -0.12686794996261597
Generated Sequence: GGTGCGGAAG


Epoch 2: 100%|██████████| 5000/5000 [07:29<00:00, 11.12it/s, loss=0.131]  


Epoch 2 average loss: 0.1176
Initializing inference params with max_seqlen=21
Prompt: "TAGTTGGACC#",	Output: "TTAAACCTGG",	Score: -0.09572743624448776
Generated Sequence: TTAAACCTGG
Initializing inference params with max_seqlen=21
Prompt: "CCGTACCTCT#",	Output: "TGCATGGAGA",	Score: -0.07370338588953018
Generated Sequence: TGCATGGAGA
Initializing inference params with max_seqlen=21
Prompt: "GTCAATGTCG#",	Output: "CAGTTACAGC",	Score: -0.012873406521975994
Generated Sequence: CAGTTACAGC
Initializing inference params with max_seqlen=21
Prompt: "TCCTCTTTCC#",	Output: "TGGAGAAAGG",	Score: -0.09253494441509247
Generated Sequence: TGGAGAAAGG
Initializing inference params with max_seqlen=21
Prompt: "CCACGCCTTC#",	Output: "TGTGCGGAAG",	Score: -0.10538376867771149
Generated Sequence: TGTGCGGAAG


Epoch 3: 100%|██████████| 5000/5000 [07:28<00:00, 11.14it/s, loss=0.0396] 


Epoch 3 average loss: 0.0687
Initializing inference params with max_seqlen=21
Prompt: "TAGTTGGACC#",	Output: "GTCAACCTGG",	Score: -0.061817556619644165
Generated Sequence: GTCAACCTGG
Initializing inference params with max_seqlen=21
Prompt: "CCGTACCTCT#",	Output: "AGCATGGAGA",	Score: -0.04560355097055435
Generated Sequence: AGCATGGAGA
Initializing inference params with max_seqlen=21
Prompt: "GTCAATGTCG#",	Output: "CAGTTACAGC",	Score: -0.006032387726008892
Generated Sequence: CAGTTACAGC
Initializing inference params with max_seqlen=21
Prompt: "TCCTCTTTCC#",	Output: "AGGAGAAAGG",	Score: -0.053739745169878006
Generated Sequence: AGGAGAAAGG
Initializing inference params with max_seqlen=21
Prompt: "CCACGCCTTC#",	Output: "GGTGCGGAAG",	Score: -0.054788053035736084
Generated Sequence: GGTGCGGAAG


## Debug Stuff

In [34]:
import inspect
print(inspect.signature(evo2_model.generate))



(prompt_seqs: List[str], n_tokens: int = 500, temperature: float = 1.0, top_k: int = 4, top_p: float = 1.0, batched: bool = True, cached_generation: bool = True, verbose: int = 1, force_prompt_threshold: int = None) -> Tuple[List[str], List[float]]


In [22]:
# Check available attributes and methods of the tokenizer
print(dir(tokenizer))


['__abstractmethods__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_vocab_size', 'clamp', 'cls', 'decode_token', 'detokenize', 'detokenize_batch', 'eod', 'eod_id', 'eos', 'eos_id', 'inv_vocab', 'mask', 'name', 'pad', 'pad_id', 'sep', 'tokenize', 'tokenize_batch', 'vocab', 'vocab_size']
