Published on

[REPLUG] Retrieval-Augmented Black-Box Language Models

Summary

  • Black Box Setting의 LM을 가정하고, 여기에 Retriver 를 적용한다.
  • 다양한 Document 를 적절하게 처리하기 위해 Ensenble 을 적용한다.
  • LM이 필요로하는 Document 를 찾아낼 수 있도록 Retriever를 파인튜닝한다.

Black Box Setting LM with Retriver

REPLUG(REtrieve and PLUG)Retrieval-augmented language modeling framework로, Black Box Setting의 언어 모델에 Retriever를 적용하여 성능을 향상시키는 방법을 제안하는 논문이다. Black Box Setting이란 모델의 파라미터에 접근이 불가능하고, 파인튜닝 역시 어려운 상황을 의미하며, 이와 반대되는 상황은 White Box Setting 이라고도 한다.

논문의 주된 문제 의식 중 하나는 지금까지 Retriver-augmented model framework 에 관한 연구들(kNN-LM, RETRO)이 많이 있었지만 대부분 White Box Setting 을 가정하고, 언어 모델에 대한 파라미터 업데이트를 요구해왔다는 점이다. 하지만 이와 같은 White Box Setting 은

  1. 언어 모델의 크기가 커짐에 따라 파인 튜닝에 필요한 연산량이 매우 커졌다는 점과
  2. API 형태로 파라미터에 접근할 수 없는 형태로 언어 모델이 많이 제공되고 있다는 점

때문에 실제 배포 환경에서 제약이 따를 수 밖에 없다.

하지만 Retriever를 적용하면 Hallucination을 줄이는 데에 도움이 될 뿐만 아니라, Domain Knowledge가 다수 요구되는 환경에서 training time 에 보지 못한 지식에 대해서도 성능 향상을 기대해 볼 수 있다. 이러한 장점을 Black Box Setting 에서도 적절하게 살리는 것이 REPLUG 가 풀고자 하는 문제라고 할 수 있다.

REPLUG

Retrieval Augmented LM

가장 단순한 Retrieval Augmented LM 은 다음과 같은 단계를 거쳐 최종적인 언어 모델의 출력을 얻게 된다.

  1. Context Text를 Retriever 에 입력으로 넣는다.
  2. Retriever 는 가지고 있는 문서의 임베딩과 과 Context Text 의 임베딩을 비교하여 유사도를 측정한다.
  3. 유사도가 높은 K(Top K)개의 문서를 반환한다.
  4. Top-K 문서와 Context Text를 붙여 언어 모델에 입력으로 넣는다.
  5. 최종적인 출력을 얻는다.

논문에서는 Document Retrieval과 Input Reformulation이라는 이름으로 REPLUG의 Retrieval Augmented LM을 설명한다.

replug-inference

Document Retrieval

언어 모델이 Input Context에 대한 응답을 생성하기 위해 필요한 Document를 찾는다는 것은 Input Context와 의미가 가장 유사한 Document를 찾는 것으로 치환하여 생각해 볼 수 있다. 이를 위해서 논문에서는 다른 연구에서 자주 활용되었던 Dual Encoder 구조, 즉 document dd와 Input Context xx를 동일한 encoder 로 encoding하고 그 값에 대한 Cosine Similarity 를 구하는 방법을 활용한다. Encoder를 E\mathbf{E}로 하여 이를 수식으로 표현하면 다음과 같다.

s(d,x)=cos(E(d),E(x))s(d, x) = \cos(\mathbf{E}(d), \mathbf{E}(x))

Input Context는 매번 바뀔 가능성이 높지만, Document와 Encoder는 특수한 경우에만 변경될 것이므로 매번 임베딩 E(d)\mathbf{E}(d)를 구하는 것이 아닌, 캐시해두고 계속 사용하는 것이 효율적이다. 또한 논문에서는 유사도가 높은 임베딩을 GPU를 활용해 빠르게 찾는 FAISS Index(Paper, Github) 를 적용하였다고 한다.

Input Reformulation

Document Retrieval 단계에서 유사도가 높은 순서대로 Top K 개의 문서를 찾았다면 이를 Context Input에 대한 응답을 생성하는 데에 사용하여야 한다. 앞서 살펴보았듯이 가장 기초적인 방법은 모든 Document 와 Input Context 를 합하여 언어 모델에 입력으로 전달하는 것이다. 하지만 이러한 방법은

  1. 입력의 길이가 과도하게 길어져 연산량이 크게 증가하거나, 심할 경우 Context Window Size를 초과하게 될 수도 있다는 점
  2. 검색된 K개의 문서들이 대등한 중요도를 가지고 사용된다는 점

등에 있어 아쉬움이 있다.

REPLUG에서는 앙상블을 적용하여 이를 해결하기 위해 시도하고 있다. [d1,d2,...,dk,x][d_1, d_2, ..., d_k, x]를 하나로 만들어 언어 모델에 입력으로 전달하는 대신 하나의 Document 와 Input Context 쌍 [d1,x][d_1, x], [d2,x][d_2, x], ... [dk,x][d_k, x] 들을 각각 언어 모델에 입력하고, 그렇게 만들어진 총 K개의 Output Probability를 앙상블하여 최종적인 Next Token을 생성하는 것이다. 위의 그림을 참고하면 보다 쉽게 이해가 가능한데, 수식으로 표현하면 다음과 같다.

p(yx,D)=dDp(ydx)λ(d,x)p(y \mid x, \mathcal{D}') = \sum_{d \in \mathcal{D}'} p(y \mid d \circ x) \cdot \lambda(d, x)

여기서 D\mathcal{D}'D\mathcal{D}의 부분 집합으로, Top K 문서들을, aba \circ b은 Concatnation을 의미한다. 또한 λ(d,x)\lambda(d, x)는 문서 dd와 Input Context xx 간의 상대적인 유사도로서 다음과 같이 Softmax 식을 통해 산출된다.

λ(d,x)=es(d,x)dDes(d,x)\lambda(d, x) = \frac{e^{s(d,x)}}{\sum_{d \in \mathcal{D}'} e^{s(d,x)}}

앙상블을 적용하여 최종적인 출력 값이 결정된다고 할 떄, 연산량이 크게 늘어나지 않는가에 대해 확인해보고 싶었다. 연산량이 늘어나게 되면 컴퓨팅 비용이 증가한다는 단점과 더불어 모델의 성능 자체도 비례해서 늘어나는 경우가 많기 때문이다. 논문에서도 이 부분에 대해 언급하고 있는데, Cross Attention 을 적용하여 처리하기 때문에 크게 늘어나지 않는다고 한다.

실제로 계산해보면 다음과 같다.

CasesSelf-Attention 연산량Cross-Attention 연산량
Case 1: [d1+d2+...+dk+x][d_1 + d_2 + ... + d_k + x]O((x+kd)2)O((x + k d)^2)N/A
Case 2: [d1+x],[d2+x],...,[dk+x][d_1 + x], [d_2 + x], ..., [d_k + x]k×O((x+d)2)k \times O((x + d)^2)k×O(xd)k \times O(x d)

Self Attention 은 전체 입력 값의 길이에 대해 기하급수적으로 늘어나는 O(N2)O(N^2)의 복잡도를 가지는 반면 Cross Attention 은 O(NM)O(NM) 의 시간 복잡도를 가지므로 언어 모델에 대해 kk번 추론을 수행하더라도 연산량이 크게 늘어나지 않는다는 것이다.

Training the Dense Retriever

REPLUG 에서는 Retriever 를 언어 모델에 맞게 파인튜닝 하는 방법 또한 제안한다. 즉, 어떤 Input Context에 대해 어떤 문서를 찾아야 언어 모델이 응답을 생성하는 데에 도움이 되는 Retriever를 만들어 보겠다는 것이다. 이를 위해서

  • Retriever 에 의해 어떤 문서가 입력과 얼마나 유사한 정도(Retrieval Likelihood)
  • 그 문서가 언어 모델이 응답을 생성하는 데에 얼마나 유용한지의 정도(LM Likelihood, Perplexity)

두 가지가 비슷해지도록 만드는 것을 Retriever 학습 방향으로 삼는다

replug-inference

Computing Retrieval Likelihood

Retriever가 어떤 문서를 선택할 확률은 Retriever 의 Search Space 인 전체 문서 D\mathcal{D}를 대상으로 진행하는 것이 이상적이다. 하지만 현실적으로 수 많은 문서들의 유사도를 구하는 것은 연산량이 많이 소모되고, 심지어 유사도가 높지 않은 대다수의 문서들의 likelihood 값은 0에 가까울 것이라 비효율적이다. 따라서 논문에서는 다음과 같이 Top K D\mathcal{D}'를 대상으로 Retrieval Likelihood 를 근사하여 사용한다.

PR(dx)=es(d,x)/γdDes(d,x)/γP_R(d \mid x) = \frac{e^{s(d,x)/\gamma}}{\sum_{d \in \mathcal{D}'} e^{s(d,x)/\gamma}}

Computing LM Likelihood

어떤 문서가 언어 모델이 응답을 생성하는 데에 도움이 되는 정도는 LM Perplexity, 즉 Next Token 에 대해 확신하는 정도를 통해 측정한다. 수식으로 표현하면 PLM(yd,x)P_{LM}(y \mid d,x) 이 되는데, 여기서 yy는 실제 ground truth output 이다. 이 값을 dd의 중요도 점수로 활용하여 LM Likelihood 를 계산하는 식을 만들어보면 다음과 같다.

Q(dx,y)=ePLM(yd,x)/βdDePLM(yd,x)/βQ(d \mid x, y) = \frac{e^{P_{LM}(y \mid d,x)/\beta}}{\sum_{d \in \mathcal{D}'} e^{P_{LM}(y \mid d,x)/\beta}}

참고로 Perplexity 는 언어 모델의 Generation 성능을 측정하기 위해 사용되는 메트릭으로 다음과 같이 계산된다.

PLL=1P(w1,w2,w3,,wN)N=1i=1NP(wiw1,w2,,wi1)NPLL = \sqrt[N]{\frac{1}{P(w_1, w_2, w_3, \dots, w_N)}} = \sqrt[N]{\frac{1}{\prod_{i=1}^{N} P(w_i \mid w_1, w_2, \dots, w_{i-1})}}

Loss function

앞의 두 likelihood 값을 고려해 볼 떄, 최종적인 Retriever 의 학습 방향은 다음과 같이 KLD 식으로 나타낼 수 있다. 이를 Loss Function 으로 삼아 Retriever를 업데이트하면 Retrieval Likelihood 와 LM Likelihood 가 유사한 분포를 가지는 Retriever 가 될 것으로 기대할 수 있다.

L=1BxBKL(PR(dx)QLM(dx,y))\mathcal{L} = \frac{1}{|\mathcal{B}|} \sum_{x \in \mathcal{B}} KL \left( P_R(d \mid x) \parallel Q_{LM}(d \mid x, y) \right)

Update of Datastore Index

끝으로 Retriever 가 업데이트되면 한 가지 문제가 되는 것이 캐싱해 둔 문서의 임베딩을 더 이상 쓸 수 없게 된다는 점이다. 이를 해결하기 위해 논문에서는 주기적으로 비동기적인 임베딩 업데이트를 진행하였다고 한다.

Training Settings

Language Model

REPLUG의 Retriever 로는 Contriever를 사용하였으며, REPLUG LSR 에서 LM 모델로는 GPT-3 Curie를 활용하였다.

Training Data

학습 데이터는 Pile Training Data에서 256 토큰 길이를 가진 800K 개의 Sequence 를 추출하였는데, 앞의 128 토큰은 Input context xx 로, 나머지 128 토큰은 ground truth yy로 설정하는 식으로 구성했다고 한다.

Documents

외부 Document Corpus D\mathcal{D} 또한 Pile Training Data 에서 36M 개의 문서를 샘플링하여 확보하였다. 이때 Trivial Retrieval 문제, 즉 Pre-Training corpus 에 Document corpus 가 포함되어 발생하는 문제를 줄이기 위해 Training Data 와 Document Corpus 간에는 중복이 발생하지 않도록 처리한다. Retriever 가 Document 를 뽑을 떄에는 FAISS Index 를 활용하여 k=20k=20 으로 설정하였다고 한다.

Experiments

Language Model Performance

replug-experiment-1

GPT-2, GPT-3 에 REPLUG를 적용한 경우와 그렇지 않은 경우를 비교하는 실험을 진행했다. 언어 모델은 Black Box Setting 을 가정하여 API로만 접근 가능하도록 하였다.

모든 실험 셋팅에서 REPLUG와 REPLUG LSR이 Frozen 모델에 비해 높은 점수를 보였으며, 이를 통해 Retriever 를 앞에 붙이면 다양한 크기와 성능을 가진 언어 모델에 대해 성능 향상을 기대해 볼 수 있다는 것을 알 수 있다.

또한 REPLUG와 Retriever 에 대한 학습을 진행한 결과인 REPLUG LSR 를 비교해보면 REPLUG LSR 이 항상 더 우수하며, 이는 Retriever 추가 학습의 유용성을 보여준다.

Downstream Task: MMLU

replug-experiment-2

MMLU(Massive Multi-task Language Understanding) 는 아래와 같이 다양한 분야에 대한 4지 선다형 문제들을 푸는 벤치마크이다. 5 Shot in-context 를 적용하여 다양한 크기의 모델에서 실험을 진행하였다.

replug-mmlu-example

175B의 크기를 갖는 Codex에 REPLUG를 적용하여 스코어가 높아지는 것을 확인할 수 있다. PaLM과 같이 크기가 3배 가까이 큰 모델과 비교해 보더라도 더 좋거나, 약간 낮은 수준의 성능을 보여주었다. 이를 통해 REPLUG를 적용함으로써 모델의 문제 해결 능력을 향상시킬 수 있다는 것을 알 수 있다.

Downstream Task: Open Domain QA

replug-experiment-3

끝으로 Open-Domain QA Task 에서의 실험 결과를 살펴보자. 데이터셋으로는 Natural Questions (NQ)TriviaQA (TQA)을 사용하고 있다. 결과적으로 적은 수준이지만 Codex 모델에서 REPLUG를 적용함으로써 성능이 향상되었고, 기존에 SOTA의 성능을 보여주던 Atlas 보다도 더욱 좋은 성과를 보였다.