AlienChen commited on
Commit
a913f53
·
verified ·
1 Parent(s): 2c8e2d0

Update models/peptide_classifiers.py

Browse files
Files changed (1) hide show
  1. models/peptide_classifiers.py +2 -2
models/peptide_classifiers.py CHANGED
@@ -530,9 +530,9 @@ class CrossAttnUnpooled(nn.Module):
530
  h = self.shared(z)
531
  return self.reg(h).squeeze(-1), self.cls(h)
532
 
533
- def load_affinity_predictor(checkpoint_path, device):
534
  """Load trained model from checkpoint."""
535
- checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
536
 
537
  model = CrossAttnUnpooled()
538
 
 
530
  h = self.shared(z)
531
  return self.reg(h).squeeze(-1), self.cls(h)
532
 
533
+ def load_affinity_predictor(device):
534
  """Load trained model from checkpoint."""
535
+ checkpoint = torch.load('./classifier_ckpt/wt_affinity.pt', map_location=device, weights_only=False)
536
 
537
  model = CrossAttnUnpooled()
538