Update models/peptide_classifiers.py
Browse files
models/peptide_classifiers.py
CHANGED
|
@@ -530,9 +530,9 @@ class CrossAttnUnpooled(nn.Module):
|
|
| 530 |
h = self.shared(z)
|
| 531 |
return self.reg(h).squeeze(-1), self.cls(h)
|
| 532 |
|
| 533 |
-
def load_affinity_predictor(
|
| 534 |
"""Load trained model from checkpoint."""
|
| 535 |
-
checkpoint = torch.load(
|
| 536 |
|
| 537 |
model = CrossAttnUnpooled()
|
| 538 |
|
|
|
|
| 530 |
h = self.shared(z)
|
| 531 |
return self.reg(h).squeeze(-1), self.cls(h)
|
| 532 |
|
| 533 |
+
def load_affinity_predictor(device):
|
| 534 |
"""Load trained model from checkpoint."""
|
| 535 |
+
checkpoint = torch.load('./classifier_ckpt/wt_affinity.pt', map_location=device, weights_only=False)
|
| 536 |
|
| 537 |
model = CrossAttnUnpooled()
|
| 538 |
|