TL;DR
- ์๋ฏธ์ ์ ์ฌ์ฑ์ ๋ฐ์ํ๊ธฐ ์ํด PLM BERT๋ฅผ ํ์ฉํ์ฌ token-level ์ ์ฌ์ฑ์ ๋ฐ์.
- IDF๋ฅผ ๊ฐ์ค์น๋ก ํ์ฉํ์ฌ importance weighting ์ ์ฉ.
ํ ์คํธ ์์ฑ ๋ชจ๋ธ์ ํ๊ฐํ ๋, ๋จ์ํ ์ดํ์ ์ผ์น(lexical matching) ์ผ๋ก๋ง ํ๊ฐํ๋ ๊ฒ์ ํ๊ณ๊ฐ ์๋ค. ์ด๋ฌํ ํ๊ณ๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํด ์ ์๋ ๋ฐฉ๋ฒ ์ค ํ๋์ธ BERTScore๋ BERT์ ๊ฐ์ ์ฌ์ ํ์ต๋ ์ธ์ด ๋ชจ๋ธ์ ํ์ฉํ์ฌ ํ ์คํธ ์์ฑ ๋ชจ๋ธ์ ์ถ๋ ฅ๊ณผ ์ฐธ์กฐ ๋ฌธ์ฅ ๊ฐ์ ์๋ฏธ์ ์ ์ฌ์ฑ์ ํ๊ฐํ๋ค. ์ด๋ ์ ํต์ ์ธ n-gram ๊ธฐ๋ฐ ํ๊ฐ ์งํ๋ณด๋ค ๋ฌธ๋งฅ์ ์๋ฏธ๋ฅผ ๋ ์ ๋ฐ์ํ์ฌ, ๋ค์ํ ์์ฐ์ด ์ฒ๋ฆฌ ๊ณผ์ ์์ ๋ ์ ๋ขฐํ ์ ์๋ ํ๊ฐ๋ฅผ ์ ๊ณตํ๋ค.
Methods


- candidate ์ reference๋ฅผ BERT์ ํต๊ณผ์์ผ contextual embedding ๊ฐ์ ์ป๋๋ค.
- token pair ๋ง๋ค cosine similarity๋ฅผ ์ด์ฉํ์ฌ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ๋ค.
- ์ ์ฌ๋ ํ๋ ฌ์ ๊ธฐ์ค์ผ๋ก BERTScore recall์ row-wise max pooling (reference ๊ฐ ํ ํฐ ๊ธฐ์ค), precision์ column-wise max pooling (candidate ๊ฐ ํ ํฐ ๊ธฐ์ค)์ผ๋ก ๊ตฌํ๋ค.
- (optional) IDF๋ฅผ ์ด์ฉํ์ฌ importance weighting ์ ๋ฐ์ํ๋ค.
Examples
๋ค์ ๋ ๋ฌธ์ฅ์ ๋ํด,
- Reference ๋ฌธ์ฅ (A): [โcatโ, โonโ, โmatโ]
- Candidate ๋ฌธ์ฅ (B): [โcatโ, โsitsโ, โonโ, โrugโ]
BERT๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ๋จ์ด์ ์๋ฒ ๋ฉ์ ๊ณ์ฐํ๊ณ , ๋ ๋ฌธ์ฅ ๊ฐ์ ์ฝ์ฌ์ธ ์ ์ฌ๋ ํ๋ ฌ์ ๊ตฌํ๋ค๊ณ ๊ฐ์ ํ์:
| Token (A) | cat (B) | sits (B) | on (B) | rug (B) |
|---|---|---|---|---|
| cat (A) | 0.95 | 0.20 | 0.10 | 0.05 |
| on (A) | 0.10 | 0.30 | 0.90 | 0.10 |
| mat (A) | 0.05 | 0.40 | 0.20 | 0.80 |
Recall(row-wise max pooling)์ ์ฐธ์กฐ๋ฌธ์ฅ A์์ ์ผ๋ง๋ ๋ง์ ์ ๋ณด๊ฐ ์์ธก๋ฌธ์ฅ B์ ์ ๋ฐ์๋์๋์ง๋ฅผ ์ธก์ ํ๋ค. ์ด๋ฅผ ์ํด ๊ฐ ํ(row) ์์ ์ต๋๊ฐ ์ ์ ํํ๊ณ , ์ด ๊ฐ๋ค์ ํ๊ท ์ ๊ณ์ฐํ๋ค:
Precision(column-wise max pooling) ์ ํ๋ณด๋ฌธ์ฅ B์ ๊ฐ ๋จ์ด๊ฐ ์ฐธ์กฐ๋ฌธ์ฅ A์ ๋จ์ด๋ค ์ค ์ผ๋ง๋ ์ ๋งค์นญ๋์๋์ง๋ฅผ ์ธก์ ํ๋ค. ์ด๋ฅผ ์ํด ๊ฐ ์ด(column) ์์ ์ต๋๊ฐ ์ ์ ํํ๊ณ , ์ด ๊ฐ๋ค์ ํ๊ท ์ ๊ณ์ฐํ๋ค:
F1 Score๋ Precision๊ณผ Recall์ ์กฐํํ๊ท ์ด๋ฏ๋ก ๋ค์๊ณผ ๊ฐ๋ค.
Code
๊ณต์ ๊ตฌํ: Tiiiger/bert_score
ํ ํฐํ & Truncation
def sent_encode(tokenizer, sent):
return tokenizer.encode(
sent, add_special_tokens=True,
max_length=tokenizer.model_max_length, # BERT=512, kcbert=300
truncation=True # ์ด๊ณผ ์ ์๋ผ๋ฒ๋ฆผ (๊ฒฝ๊ณ ์์)
)[CLS],[SEP]์๋ ์ถ๊ฐ โ ์ค์ ์ฌ์ฉ ๊ฐ๋ฅ ํ ํฐ์max_length - 2- ๋ฌธ์ ์์ค ์ ๋ ฅ ์ ๋ท๋ถ๋ถ ์ ๋ณด ์์ค
๋ ์ด์ด ์ ํ
# ๋ชจ๋ธ ๋ก๋ ํ ์ง์ ๋ ์ด์ด๊น์ง๋ง ๋ฌผ๋ฆฌ์ ์ผ๋ก ์๋ผ๋
model.encoder.layer = nn.ModuleList(
[layer for layer in model.encoder.layer[:num_layers]]
)
# ์: num_layers=9 โ layer 0~8๋ง ์ฌ์ฉ, layer 8 ์ถ๋ ฅ์ด ์ต์ข
embedding- ๊ธฐ๋ณธ๊ฐ์ ๋ชจ๋ธ๋ณ ์ต์ ๋ ์ด์ด (๋ ผ๋ฌธ Appendix B์์ WMT16์ผ๋ก ํ๋)
- ์ค๊ฐ ๋ ์ด์ด๊ฐ semantic similarity์ ์ต์ , ์์ ๋ ์ด์ด๋ MLM์ ํนํ
ํต์ฌ ์ค์ฝ์ด๋ง โ greedy_cos_idf
def greedy_cos_idf(ref_embedding, ref_masks, ref_idf,
hyp_embedding, hyp_masks, hyp_idf):
# (1) L2 ์ ๊ทํ โ ๋ด์ = cosine similarity
ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
# (2) ์ ์ฌ๋ ํ๋ ฌ: [batch, hyp_len, ref_len]
sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
# (3) Greedy matching
word_precision = sim.max(dim=2)[0] # candidate ๊ฐ ํ ํฐ โ ref ์ต๋ ๋งค์นญ
word_recall = sim.max(dim=1)[0] # reference ๊ฐ ํ ํฐ โ hyp ์ต๋ ๋งค์นญ
# (4) IDF ๊ฐ์ค ํ๊ท โ ๋ฌธ์ฅ ์์ค ์ ์
hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True)) # ์ ๊ทํ (ํฉ=1)
ref_idf.div_(ref_idf.sum(dim=1, keepdim=True))
P = (word_precision * hyp_idf).sum(dim=1) # Precision
R = (word_recall * ref_idf).sum(dim=1) # Recall
F = 2 * P * R / (P + R) # F1
return P, R, Fdim ๋ฐฉํฅ ์ดํด โ ์ Examples ํ๋ ฌ๊ณผ ๋์
์ฝ๋์
sim์ Examples ํ๋ ฌ๊ณผ ์ ์น ๊ด๊ณ
- Examples: ํ=Reference(A), ์ด=Candidate(B) โ shape
[3, 4]- ์ฝ๋: ํ=Candidate(hyp), ์ด=Reference(ref) โ shape
[4, 3]
Examples ํ๋ ฌ์ ์ ์นํ๋ฉด ์ฝ๋์ sim์ด ๋๋ค:

| ์ฐ์ฐ | dim | ์ถ ๋ฐฉํฅ | ์๋ฏธ | Examples ๋์ |
|---|---|---|---|---|
sim.max(dim=2) | ref ์ถ ์ ๊ฑฐ | โ ๋ฐฉํฅ | hyp ๊ฐ ํ ํฐ์ best match in ref | column-wise max (Precision) |
sim.max(dim=1) | hyp ์ถ ์ ๊ฑฐ | โ ๋ฐฉํฅ | ref ๊ฐ ํ ํฐ์ best match in hyp | row-wise max (Recall) |
๋จ๊ณ๋ณ ์ ๋ฆฌ
| ๋จ๊ณ | ์ฐ์ฐ | ์ฝ๋ |
|---|---|---|
| ์ ์ฌ๋ ํ๋ ฌ | L2 norm ํ batch matmul | bmm(hyp, ref.T) โ [B, H, R] |
| Precision | hyp ๊ฐ ํ ํฐ์ ref ์ต๋ ๋งค์นญ | sim.max(dim=2) |
| Recall | ref ๊ฐ ํ ํฐ์ hyp ์ต๋ ๋งค์นญ | sim.max(dim=1) |
| IDF ๊ฐ์ค | ์ ๊ทํ๋ IDF๋ฅผ ๊ฐ์ค์น๋ก ๊ณฑ | (word_score * idf).sum() |
| F1 | ์กฐํํ๊ท | F = 2PR/(P+R) |
IDF ๊ณ์ฐ
def get_idf_dict(arr, tokenizer, nthreads=4):
idf_count = Counter()
num_docs = len(arr)
# (1) ๊ฐ reference ๋ฌธ์ฅ์ ํ ํฐํํ์ฌ DF ์ง๊ณ
for sent in arr:
tokens = sent_encode(tokenizer, sent)
for token_id in set(tokens): # set โ ๋ฌธ์ฅ ๋ด ์ค๋ณต ๋ฌด์
idf_count[token_id] += 1 # ํด๋น ํ ํฐ์ด ๋ฑ์ฅํ ๋ฌธ์ฅ ์ (DF)
# (2) IDF = log((N+1)/(df+1)) โ Laplace smoothing
idf_dict = defaultdict(
lambda: log((num_docs + 1) / 1) # ๋ฏธ๋ฑ์ฅ ํ ํฐ ๊ธฐ๋ณธ๊ฐ (์ต๋ IDF)
)
idf_dict.update({
idx: log((num_docs + 1) / (c + 1))
for (idx, c) in idf_count.items()
})
return idf_dict- (1) reference ์ฝํผ์ค์ ๊ฐ ๋ฌธ์ฅ์ ์ํํ๋ฉฐ ํ ํฐ๋ณ ๋ฌธ์ ๋น๋(DF) ์ง๊ณ.
set(tokens)์ผ๋ก ํ ๋ฌธ์ฅ ๋ด ์ค๋ณต ์นด์ดํ ๋ฐฉ์ง - (2) Laplace smoothing์ผ๋ก 0-division ๋ฐฉ์ง. ๋ฏธ๋ฑ์ฅ ํ ํฐ์ ์ต๋ IDF ๋ถ์ฌ
- IDF ๋ฏธ์ฌ์ฉ ์ (
idf=False) ๋ชจ๋ ํ ํฐ ๊ท ๋ฑ ๊ฐ์ค
Key Concepts to Clarify
- ๋ ์ด์ด ์ ํ: BERT์ ๋ชจ๋ ๋ ์ด์ด๊ฐ ๋์ผํ๊ฒ ์ ์ฉํ์ง ์๋ค. ๋ ผ๋ฌธ์์ ์ค๊ฐ ๋ ์ด์ด๊ฐ semantic similarity์ ์ต์ ์ด๋ฉฐ, ์ต์ข ๋ ์ด์ด๋ pretraining objective์ ํนํ๋์ด ์ฑ๋ฅ์ด ๋จ์ด์ง์ ํ์ธ. ๊ฐ ๋ชจ๋ธ๋ณ๋ก WMT16์ validation์ผ๋ก ์ต์ ๋ ์ด์ด๋ฅผ ํ์ (Appendix B).
- Greedy matching vs Optimal matching: BERTScore๋ greedy matching (๊ฐ ํ ํฐ์ ๊ฐ์ฅ ์ ์ฌํ ์๋ ํ ํฐ์ ๋งค์นญ) ์ ํ. WMD ๊ธฐ๋ฐ optimal matching(Earth Moverโs Distance)์ผ๋ก ๊ต์ฒดํด๋ ์ผ๊ด๋ ๊ฐ์ ์์ (Appendix C). MoverScore๋ ๊ฐ์ ๋งฅ๋ฝ์์ optimal ์ ํ.
- Baseline rescaling: cosine similarity ๋ฒ์๊ฐ ์ด๋ก ์ ์ด๋ ์ค์ ๋ก๋ ์ข์ ๊ตฌ๊ฐ์ ๋ถํฌ. Common Crawl์์ ๋๋ค ๋ฌธ์ฅ ์์ผ๋ก empirical lower bound ๋ฅผ ๊ตฌํด ๋ก rescalingํ์ฌ ๊ฐ๋ ์ฑ ํฅ์. ๋ญํน ๋ฅ๋ ฅ์๋ ์ํฅ ์์.
Connections
- MoverScore โ ๊ฐ์ contextual embedding ๊ธฐ๋ฐ์ด๋ optimal matching(Word Moverโs Distance,WMD) ์ ํ
- BLEU โ BERTScore๊ฐ ํด๊ฒฐํ๋ ค๋ n-gram ๊ธฐ๋ฐ ๋ฉํธ๋ฆญ์ ๋ํ
- METEOR โ stem/synonym fallback์ผ๋ก exact match ๋ณด์, BERTScore์ ์ ํ
Source Trail
- ์ด ๋ ธํธ๋ @zhangBERTScoreEvaluatingText2020์์ ์ถ์ถ๋จ
- KoBERTScore (ํ๊ตญ์ด ๊ตฌํ)
Discussion
Comments
๋๊ธ์ ์น์ธ ํ ๊ณต๊ฐ๋ฉ๋๋ค.