트랜스포머 기반의 분자 학습 도구, Uni-Mol
⚠️ 이번 글은 LLM(Large Language Model)의 학습에 사용되는 Transformer에 대한 기초적인 선행지식이 요구됩니다
오늘은 Uni-Mol이 Transformer 모델을 이용해 분자의 3차원 위치 정보를 어떻게 예측하는지 간략하게 살펴보겠습니다. 그리고 그 중에서 히츠에서 필요한 태스크 중 하나인, 단백질-리간드의 결합 포즈 예측을 중심으로 살펴보겠습니다.
ICLR 2023에 실린 Uni-Mol이란?
Uni-Mol은 2023년 ICLR에 실린 “UNI-MOL: A UNIVERSAL 3D MOLECULAR REPRESENTATION LEARNING FRAMEWORK”이라는 연구에서 처음 소개되었습니다. 그 이름이 시사하듯이 다양한 태스크를 수행할 수 있는 framework로써 개발되었으며, 단백질과 분자의 3차원 좌푯값을 예측합니다. 저자들에 따르면 ADMET와 같은 분자의 물성 예측, 3차원 컨포머 생성, 단백질-리간드의 결합 포즈 예측 등에 이용될 수 있다고 합니다. Uni-Mol의 큰 특징은 Transformer 기반 분자 표현 학습(Molecular Representation Learning)이라는 것입니다.
트랜스포머 기반의 Backbone 구조
Uni-Mol은 Transformer 모델을 기반으로 하고 있습니다. 트랜스포머는 모두 알고 계신대로 LLM(Large Language Model) 학습에 주로 사용되는데요. Uni-Mol에서 다양한 딥러닝 모델 중에 트랜스포머를 선택한 이유는 멀리 떨어진 원자들 사이의 정보를 학습하기가 용이하기 때문입니다. 물론 트랜스포머를 그대로 사용할 수는 없습니다. 왜냐하면 트랜스포머는 자연어를 처리하는데 특화된 모델이기 때문입니다. 따라서 다음의 요소들을 추가적으로 고려해야 합니다.
첫번째, 연속적인 값을 다뤄야 합니다. 우리가 예측하고자 하는 3차원 좌표의 좌푯값은 연속적입니다. 반면에 트랜스포머에서 사용되는 토큰의 위칫값은 이산적이죠. 따라서 Positional encoding 과정에서 기존의 트랜스포머처럼 이산적인 값이 아니라 연속적인 값을 다뤄야 합니다. 이를 위해서 가우시안 커널을 사용해 float 타입을 embedding 하는 방식을 이용했습니다.
두번째, Positional encoding 과정은 Rotation과 Translation에 대해 Invariant 해야합니다. 왜냐하면 임의의 공간에서 단백질이든, 리간드든 회전시키면서 위치를 이동해도 결국 같은 단백질과 같은 리간드니까요. 이 두번째 제약을 극복하기 위해 모든 원자 쌍들에 대해 거리(Pair-distance)를 계산합니다. 이렇게 원자간의 거리를 이용하면 이후 과정에서 Rotation과 Translation에 대해 Invarint하게 연산을 수행할 수 있기 때문이죠.
Pair represenetation 계산
이제 Pair representation을 계산할 차례입니다. 특정 원자와 그 원자의 쌍이 되는 하나의 쌍을 Pair라고 하고 그 쌍에 대한 표현 정보를 Pair representation이라 합니다. Pair representation은 다음과 같이 계산합니다.
Uni-Mol은 더 나은 성능을 위해 multi-head attention 을 사용합니다. Query, Key, Value를 각각 H 등분하여 병렬로 연산을 진행합니다.
이렇게 Pair representation을 계산하는 이유는 바로 아래에 등장하는 Context vector를 구하기 위함입니다. 우리 눈에 익숙한 Attention mechinism 식에서 Pair representation 항이 추가된 것을 볼수 있습니다. Context vector는 Pretrain 과정의 최종 출력이 됩니다.
사전학습 진행하기 - Pretraining
기존 Transformer에서는 특정 토큰을 Masking한 후, 이 토큰을 맞히는 방식으로 Encoder의 학습이 진행됩니다. Uni-Mol에서는 두 가지의 Pretraining 모델이 필요합니다. 특정 원자가 어떤 종류의 원소인지 예측하는 모델과 특정 원자의 3차원 좌푯값을 예측하는 모델입니다.
이 두 가지 모델을 학습하기 위해 Uni-Mol에서는 다음의 데이터를 사용했습니다.
- 3차원 좌푯값을 포함하는 분자(리간드) 약 2억개
- 3차원 좌푯값을 포함하는 단백질 포켓 구조 약 320만개
먼저 특정 원자가 어떤 종류의 원소인지를 예측하는 첫 번째 모델에서는 기존의 Transformer처럼 Masking 기법을 활용합니다. 데이터세트에 포함된 분자 혹은 단백질의 임의의 원자를 특수한 기호로 대체합니다. 그리고 이 원자의 3차원 좌푯값은 분자 혹은 단백질에 포함된 전체 원자들의 무게 중심으로 설정합니다.
하지만 두 번째 모델인 3차원 좌푯값을 예측하는 모델에서는 위와 같은 방식을 사용할 수 없습니다. 상술한 것처럼 Uni-Mol이 예측할 원자들의 3차원 좌푯값은 이산적인 값이 아니라 연속적인 값이기 때문입니다.
따라서 Uni-Mol에서는 기존의 Transformer처럼 토큰을 마스킹 하는 대신, Ground truth의 원자 위치에 최대 1Å만큼의 가우시안 노이즈를 추가하는 방식으로 3차원 좌푯값을 변화시켰습니다. 그리고 이렇게 변화된 3차원 좌푯값이 정답 구조와 잘 맞도록 학습이 진행됩니다.
위의 학습 과정을 통해 Atom represention과 Pair representation을 학습하는 것이 이 Pretraining 과정의 목적입니다.
단백질-리간드 결합포즈 예측하기 - Finetuning
앞서 우리는 Pretraining 모델을 학습하는 방법을 알아보았습니다. 읽으시는데 오래 걸리셨죠? 마지막 단계입니다. 이제 우리의 목적인 단백질-리간드 결합 포즈를 예측할 차례입니다. 예측하고 싶은 분자(리간드)와 단백질을 각각 Pretraining 모델에 넣어서 Encoding 해줍니다. 그러면 분자와 단백질 각각에 대한 Atom representation과 Pair representation을 얻을 수 있겠죠?
이제부터는 Decoding 과정입니다. Decoder의 입력은 Encoder의 모든 Representation을 concat한 값, 즉 초기 도킹 포즈에 대한 표현 정보입니다. 이 초기 도킹 포즈는 조금만 무작위로 선정해도 Decoder의 출력값이 매우 크게 변동할 수 있습니다. 따라서 초기 도킹 포즈를 얻기 위해 Pretrained model을 이용해 Encoding 했던 것입니다. 그들의 Representation을 concat하여 Decoder의 입력으로 사용합니다.
Docoder 내부에서는 현재의 Pair-distance matrix와 예측 Pair-distance matrix 의 MSE(Mean Square Error)를 Loss 로 삼아 Back-propagation을 진행하며 학습이 진행됩니다. 이 과정에서 단백질과 리간드의 결합 포즈가 최적화 되는 것입니다. Uni-Mol의 연구진은 Uni-Mol의 단백질-리간드 결합 포츠 예측 성능을 측정하기 위해 CASF-2016의 데이터를 벤치마크로 사용하였습니다.
Finetuning의 결과, Uni-Mol의 단백질-리간드 결합 포즈 예측의 정확도는 80.35% 라고 합니다. CASF-2016 데이터에는 총 285개의 단백질-리간드 쌍이 있습니다. 이 80.35%라는 수치는, 총 285개 단백질-리간드 쌍 중에 정답구조와의 RMSD가 2Å이하로 측정된 결과물이 229개가 나왔다는 의미입니다. 이는 논문이 출간되던 시점을 기준으로 SOTA에 달하는 성능입니다.
마치며
최근 Transformer 기반의 딥러닝 모델이 컴퓨터 비전 뿐만 아니라 다양한 분야에 응용되고 있는데요. Uni-Mol 이후로도 어떤 새로운 모델이 개발될지 기대가 됩니다. 이상으로 오늘의 포스팅을 마칩니다. 그럼 다음에는 도움이 될 수 있는 새로운 주제로 찾아뵙겠습니다.