Hi,
I successfully ran the fine-tuning code using config/pretrain/saprot.py and config/Thermostability/saprot.py, and I now have a few new questions.
I would really appreciate it if you could answer them.
1. Could you share the sequence recovery (sequence design) code?
I implemented it my own way, but I am not sure whether it is correct.
My pseudocode is as follows (see the sketch after the example):
Given:
initial_tokens = ['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv'] (initial sequence)
input_tokens = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv'] (sequence with the amino-acid subtoken masked: ##, #p, #y, #d)
The model predicts a full token (sequence + structure subtoken) at each masked position, so the predicted structure subtoken can be wrong
(e.g. ## -> Gr, #p -> Gp, #y -> Gp, #d -> Sd).
Then I extract only the sequence subtoken from each prediction and recombine it with the original structure subtoken:
input_tokens = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv']
recovered_tokens = ['G#', 'Ev', 'Gp', 'Qp', 'L#', 'Gy', 'Sd', 'Ya', 'Kv']
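To make sure I am describing this correctly, here is a minimal sketch of what I currently do. It assumes the Hugging Face checkpoint westlake-repl/SaProt_650M_AF2 loaded through EsmTokenizer / EsmForMaskedLM (as in your README); the recovery-rate bookkeeping at the end is my own addition, so please correct me if your design code works differently.
```python
import torch
from transformers import EsmTokenizer, EsmForMaskedLM

# Assumed checkpoint name; replace with your local SaProt path if needed.
model_name = "westlake-repl/SaProt_650M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name).eval()

initial_tokens = ['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv']
input_tokens   = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv']

inputs = tokenizer("".join(input_tokens), return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits[0]  # (seq_len + 2, vocab_size); CLS/EOS included

recovered_tokens, n_correct, n_masked = [], 0, 0
for i, (inp, init) in enumerate(zip(input_tokens, initial_tokens)):
    if inp[0] != '#':                   # amino acid not masked at this position
        recovered_tokens.append(inp)
        continue
    pred_id = logits[i + 1].argmax(-1).item()              # +1 skips the CLS token
    pred_token = tokenizer.convert_ids_to_tokens(pred_id)  # e.g. 'Gr'
    # Keep only the predicted amino-acid subtoken; keep the original structure subtoken.
    recovered_tokens.append(pred_token[0] + inp[1])
    n_masked += 1
    n_correct += int(pred_token[0] == init[0])

print(recovered_tokens)
print(f"sequence recovery on masked positions: {n_correct / n_masked:.2f}")
```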
2. I also wrote code to generate the .mdb dataset file.
I checked that it runs, but I am not sure whether the ids can be arbitrary (e.g. 550, 5500) or whether they must follow a specific scheme.
I would appreciate it if you could check this code against yours.
Generating the .mdb file
```python
import lmdb
import json

# Example data
data = {
    "550": {"description": "A0A0J6SSW7", "seq": "M#R#A#A#A#T#L#L#V#T#L#C#V#V#G#A#N#E#A#R#A#GfIwLe..."},
    "5500": {"description": "A0A535NFD5", "seq": "AdAvRvEvAvLvRvAvSvGvHdPdFdVdEdAdPpGpEpAaAdFp..."},
    # Add more entries here
}

# Open (or create) an LMDB environment
env = lmdb.open("my_lmdb_file", map_size=int(1e9))  # map_size is the maximum size (in bytes) of the DB

with env.begin(write=True) as txn:
    # Store the dataset length, for `return int(self._get("length"))` in SaprotFoldseekDataset
    length = len(data)
    txn.put("length".encode("utf-8"), str(length).encode("utf-8"))

    for key, value in data.items():
        # Convert the value to a JSON string
        value_json = json.dumps(value)
        # Store key-value pairs in the database; keys and values must be bytes
        txn.put(key.encode("utf-8"), value_json.encode("utf-8"))

# Close the LMDB environment
env.close()
```
Reading the .mdb file
```python
import lmdb

env = lmdb.open("my_lmdb_file", readonly=True)
with env.begin() as txn:
    cursor = txn.cursor()
    for key, value in cursor:
        print(key.decode("utf-8"), value.decode("utf-8"))
env.close()
```
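Related to the id question: my (possibly wrong) assumption from `return int(self._get("length"))` is that the dataset also fetches samples by consecutive integer index, so I added the small check below that keys "0" ... "length - 1" exist. With arbitrary ids such as 550/5500 this check fails, which is exactly the point I would like you to confirm or rule out.
```python
import lmdb

# Sanity check based on my assumption (please correct me if it is wrong) that
# SaprotFoldseekDataset looks samples up by consecutive integer keys "0" ... "length - 1".
env = lmdb.open("my_lmdb_file", readonly=True, lock=False)
with env.begin() as txn:
    length = int(txn.get("length".encode("utf-8")))
    missing = [i for i in range(length) if txn.get(str(i).encode("utf-8")) is None]
    print(f"length = {length}, missing integer keys: {missing}")
env.close()
```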
3. I once asked whether PEFT is possible, and you kindly answered that it is available in SaprotBaseModel.py.
In that code I could see that LoRA can be used for downstream tasks.
In my case, I was hoping to first use LoRA for MLM fine-tuning on a particular protein domain,
and then fine-tune further on a downstream task.
I have written code for this, but I do not think this approach has been tried here before,
so I would like your opinion on whether it is viable.
The steps would be (a rough sketch follows this question):
- Load SaProt model weights
- Use LoRA for MLM fine-tuning
- Load (SaProt model weights + LoRA MLM fine-tuning weights)
- Fine-tune on the downstream task
- Load (SaProt model weights + LoRA MLM fine-tuning weights + LoRA downstream fine-tuning weights)
- Predict on the downstream task
Alternatively, since the steps above are rather complicated, the downstream task could simply be done on embeddings extracted from
(SaProt model weights + LoRA MLM fine-tuning weights).
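To make the two-stage plan concrete, here is a rough sketch of what I have in mind. It uses the Hugging Face peft library directly rather than the LoRA handling in SaprotBaseModel.py, and the checkpoint name, adapter directories, target module names and LoRA hyperparameters are all placeholders/assumptions on my side.
```python
from transformers import EsmForMaskedLM, EsmForSequenceClassification
from peft import LoraConfig, PeftModel, TaskType, get_peft_model

base = "westlake-repl/SaProt_650M_AF2"                      # assumed checkpoint name
lora_cfg = dict(r=8, lora_alpha=16, lora_dropout=0.05,
                target_modules=["query", "key", "value"])   # assumed ESM attention module names

# Stage 1: LoRA MLM fine-tuning on the protein domain of interest
mlm_model = get_peft_model(EsmForMaskedLM.from_pretrained(base), LoraConfig(**lora_cfg))
# ... run the MLM training loop on the domain-specific sequences ...
mlm_model.save_pretrained("lora_mlm_adapter")               # saves only the adapter weights

# Stage 2: merge (SaProt weights + LoRA MLM weights) into a single backbone
merged = PeftModel.from_pretrained(
    EsmForMaskedLM.from_pretrained(base), "lora_mlm_adapter"
).merge_and_unload()
merged.save_pretrained("saprot_domain_adapted")

# Stage 3: attach a fresh LoRA adapter for the downstream task
downstream = EsmForSequenceClassification.from_pretrained(
    "saprot_domain_adapted", num_labels=1                   # regression head, e.g. thermostability
)
downstream = get_peft_model(downstream, LoraConfig(task_type=TaskType.SEQ_CLS, **lora_cfg))
# ... fine-tune `downstream` on the labelled data, then run prediction with it ...
```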
4. When I ran the code using config/pretrain/saprot.py or config/Thermostability/saprot.py,
it seems that only one model is saved after training.
If so, how can I know whether the saved model is the optimal one?
I could see that in the Trainer config, enable_checkpointing: false.
Should I change it to true, track the results with wandb, and pick the best model from there? Something like the sketch below?
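For example, is this the intended way to keep the best checkpoint? A minimal sketch, assuming the Trainer is a plain PyTorch Lightning Trainer (which enable_checkpointing suggests), with "valid_loss" as a placeholder for whatever metric the validation loop actually logs.
```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    monitor="valid_loss",   # placeholder metric name
    mode="min",             # keep the checkpoint with the lowest validation loss
    save_top_k=1,
)
trainer = Trainer(enable_checkpointing=True, callbacks=[checkpoint_cb])
# trainer.fit(model, datamodule)
# checkpoint_cb.best_model_path then points to the best checkpoint
```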
Thank you for reading these long inquiries. Your answers will be very helpful to me :)