TL;DR

  • ์˜๋ฏธ์  ์œ ์‚ฌ์„ฑ์„ ๋ฐ˜์˜ํ•˜๊ธฐ ์œ„ํ•ด PLM BERT๋ฅผ ํ™œ์šฉํ•˜์—ฌ token-level ์œ ์‚ฌ์„ฑ์„ ๋ฐ˜์˜.
  • IDF๋ฅผ ๊ฐ€์ค‘์น˜๋กœ ํ™œ์šฉํ•˜์—ฌ importance weighting ์ ์šฉ.

ํ…์ŠคํŠธ ์ƒ์„ฑ ๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•  ๋•Œ, ๋‹จ์ˆœํ•œ ์–ดํœ˜์  ์ผ์น˜(lexical matching) ์œผ๋กœ๋งŒ ํ‰๊ฐ€ํ•˜๋Š” ๊ฒƒ์€ ํ•œ๊ณ„๊ฐ€ ์žˆ๋‹ค. ์ด๋Ÿฌํ•œ ํ•œ๊ณ„๋ฅผ ๊ทน๋ณตํ•˜๊ธฐ ์œ„ํ•ด ์ œ์•ˆ๋œ ๋ฐฉ๋ฒ• ์ค‘ ํ•˜๋‚˜์ธ BERTScore๋Š” BERT์™€ ๊ฐ™์€ ์‚ฌ์ „ ํ•™์Šต๋œ ์–ธ์–ด ๋ชจ๋ธ์„ ํ™œ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ ์ƒ์„ฑ ๋ชจ๋ธ์˜ ์ถœ๋ ฅ๊ณผ ์ฐธ์กฐ ๋ฌธ์žฅ ๊ฐ„์˜ ์˜๋ฏธ์  ์œ ์‚ฌ์„ฑ์„ ํ‰๊ฐ€ํ•œ๋‹ค. ์ด๋Š” ์ „ํ†ต์ ์ธ n-gram ๊ธฐ๋ฐ˜ ํ‰๊ฐ€ ์ง€ํ‘œ๋ณด๋‹ค ๋ฌธ๋งฅ์  ์˜๋ฏธ๋ฅผ ๋” ์ž˜ ๋ฐ˜์˜ํ•˜์—ฌ, ๋‹ค์–‘ํ•œ ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ ๊ณผ์ œ์—์„œ ๋” ์‹ ๋ขฐํ•  ์ˆ˜ ์žˆ๋Š” ํ‰๊ฐ€๋ฅผ ์ œ๊ณตํ•œ๋‹ค.

Methods

  1. Pass the candidate and the reference through BERT to obtain contextual embeddings.
  2. Compute the cosine similarity for every token pair.
  3. From the similarity matrix, BERTScore recall is row-wise max pooling (one max per reference token) and precision is column-wise max pooling (one max per candidate token).
  4. (optional) Apply importance weighting with IDF.
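The four steps above can be sketched end-to-end with numpy, using tiny random arrays in place of real BERT embeddings (the shapes, the rng seed, and the uniform IDF weights are all illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
# (1) Stand-ins for contextual embeddings: 3 reference tokens, 4 candidate tokens, dim 8
ref = rng.normal(size=(3, 8))
hyp = rng.normal(size=(4, 8))

# L2-normalize so that dot product == cosine similarity
ref = ref / np.linalg.norm(ref, axis=-1, keepdims=True)
hyp = hyp / np.linalg.norm(hyp, axis=-1, keepdims=True)

# (2) Token-pair similarity matrix: [hyp_len, ref_len]
sim = hyp @ ref.T

# (3) Greedy matching: each token takes its best match on the other side
word_precision = sim.max(axis=1)  # per candidate token
word_recall = sim.max(axis=0)     # per reference token

# (4) Importance weighting; uniform weights here reduce to plain averaging
hyp_idf = np.full(4, 1 / 4)
ref_idf = np.full(3, 1 / 3)
P = (word_precision * hyp_idf).sum()
R = (word_recall * ref_idf).sum()
F = 2 * P * R / (P + R)
```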

Examples

๋‹ค์Œ ๋‘ ๋ฌธ์žฅ์— ๋Œ€ํ•ด,

  • Reference ๋ฌธ์žฅ (A): [โ€œcatโ€, โ€œonโ€, โ€œmatโ€]
  • Candidate ๋ฌธ์žฅ (B): [โ€œcatโ€, โ€œsitsโ€, โ€œonโ€, โ€œrugโ€]

BERT๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ๋‹จ์–ด์˜ ์ž„๋ฒ ๋”ฉ์„ ๊ณ„์‚ฐํ•˜๊ณ , ๋‘ ๋ฌธ์žฅ ๊ฐ„์˜ ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„ ํ–‰๋ ฌ์„ ๊ตฌํ–ˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜์ž:

Token      cat (B)   sits (B)   on (B)   rug (B)
cat (A)      0.95      0.20      0.10     0.05
on (A)       0.10      0.30      0.90     0.10
mat (A)      0.05      0.40      0.20     0.80

Recall (row-wise max pooling) measures how much of the information in the reference sentence A is captured by the candidate sentence B. Take the maximum of each row and average those values.

Precision (column-wise max pooling) measures how well each token of the candidate sentence B matches some token of the reference sentence A. Take the maximum of each column and average those values.

F1 Score๋Š” Precision๊ณผ Recall์˜ ์กฐํ™”ํ‰๊ท ์ด๋ฏ€๋กœ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.


Code

๊ณต์‹ ๊ตฌํ˜„: Tiiiger/bert_score

ํ† ํฐํ™” & Truncation

def sent_encode(tokenizer, sent):
    return tokenizer.encode(
        sent, add_special_tokens=True,
        max_length=tokenizer.model_max_length,  # BERT=512, kcbert=300
        truncation=True  # anything longer is silently truncated (no warning)
    )
  • [CLS] and [SEP] are added automatically → usable tokens are max_length - 2
  • For document-level inputs, everything past the limit is lost

๋ ˆ์ด์–ด ์„ ํƒ

# ๋ชจ๋ธ ๋กœ๋“œ ํ›„ ์ง€์ • ๋ ˆ์ด์–ด๊นŒ์ง€๋งŒ ๋ฌผ๋ฆฌ์ ์œผ๋กœ ์ž˜๋ผ๋ƒ„
model.encoder.layer = nn.ModuleList(
    [layer for layer in model.encoder.layer[:num_layers]]
)
# ์˜ˆ: num_layers=9 โ†’ layer 0~8๋งŒ ์‚ฌ์šฉ, layer 8 ์ถœ๋ ฅ์ด ์ตœ์ข… embedding
  • ๊ธฐ๋ณธ๊ฐ’์€ ๋ชจ๋ธ๋ณ„ ์ตœ์  ๋ ˆ์ด์–ด (๋…ผ๋ฌธ Appendix B์—์„œ WMT16์œผ๋กœ ํŠœ๋‹)
  • ์ค‘๊ฐ„ ๋ ˆ์ด์–ด๊ฐ€ semantic similarity์— ์ตœ์ , ์ƒ์œ„ ๋ ˆ์ด์–ด๋Š” MLM์— ํŠนํ™”

ํ•ต์‹ฌ ์Šค์ฝ”์–ด๋ง โ€” greedy_cos_idf

def greedy_cos_idf(ref_embedding, ref_masks, ref_idf,
                   hyp_embedding, hyp_masks, hyp_idf):
 
    # (1) L2 normalization → dot product becomes cosine similarity
    ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
    hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
 
    # (2) Similarity matrix: [batch, hyp_len, ref_len]
    sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
 
    # (3) Greedy matching
    word_precision = sim.max(dim=2)[0]  # best ref match for each candidate token
    word_recall    = sim.max(dim=1)[0]  # best hyp match for each reference token
 
    # (4) IDF-weighted average → sentence-level scores
    hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True))  # normalize (sum to 1)
    ref_idf.div_(ref_idf.sum(dim=1, keepdim=True))
 
    P = (word_precision * hyp_idf).sum(dim=1)  # Precision
    R = (word_recall    * ref_idf).sum(dim=1)  # Recall
    F = 2 * P * R / (P + R)                    # F1
    return P, R, F

dim ๋ฐฉํ–ฅ ์ดํ•ด โ€” ์œ„ Examples ํ–‰๋ ฌ๊ณผ ๋Œ€์‘

์ฝ”๋“œ์˜ sim์€ Examples ํ–‰๋ ฌ๊ณผ ์ „์น˜ ๊ด€๊ณ„

  • Examples: ํ–‰=Reference(A), ์—ด=Candidate(B) โ†’ shape [3, 4]
  • ์ฝ”๋“œ: ํ–‰=Candidate(hyp), ์—ด=Reference(ref) โ†’ shape [4, 3]

Examples ํ–‰๋ ฌ์„ ์ „์น˜ํ•˜๋ฉด ์ฝ”๋“œ์˜ sim์ด ๋œ๋‹ค:

์—ฐ์‚ฐdim์ถ• ๋ฐฉํ–ฅ์˜๋ฏธExamples ๋Œ€์‘
sim.max(dim=2)ref ์ถ• ์ œ๊ฑฐโ†’ ๋ฐฉํ–ฅhyp ๊ฐ ํ† ํฐ์˜ best match in refcolumn-wise max (Precision)
sim.max(dim=1)hyp ์ถ• ์ œ๊ฑฐโ†“ ๋ฐฉํ–ฅref ๊ฐ ํ† ํฐ์˜ best match in hyprow-wise max (Recall)
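This correspondence can be checked numerically on the Examples matrix (a numpy sketch with uniform weighting, i.e. plain averaging instead of IDF):

```python
import numpy as np

# Examples matrix: rows = reference A tokens, cols = candidate B tokens
sim_example = np.array([
    [0.95, 0.20, 0.10, 0.05],   # cat (A)
    [0.10, 0.30, 0.90, 0.10],   # on (A)
    [0.05, 0.40, 0.20, 0.80],   # mat (A)
])

# Code orientation: rows = candidate (hyp), cols = reference (ref)
sim = sim_example.T             # shape [4, 3]

word_precision = sim.max(axis=1)  # == column-wise max of the Examples matrix
word_recall = sim.max(axis=0)     # == row-wise max of the Examples matrix

P = word_precision.mean()   # (0.95 + 0.40 + 0.90 + 0.80) / 4 = 0.7625
R = word_recall.mean()      # (0.95 + 0.90 + 0.80) / 3 ≈ 0.8833
```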

๋‹จ๊ณ„๋ณ„ ์ •๋ฆฌ

๋‹จ๊ณ„์—ฐ์‚ฐ์ฝ”๋“œ
์œ ์‚ฌ๋„ ํ–‰๋ ฌL2 norm ํ›„ batch matmulbmm(hyp, ref.T) โ†’ [B, H, R]
Precisionhyp ๊ฐ ํ† ํฐ์˜ ref ์ตœ๋Œ€ ๋งค์นญsim.max(dim=2)
Recallref ๊ฐ ํ† ํฐ์˜ hyp ์ตœ๋Œ€ ๋งค์นญsim.max(dim=1)
IDF ๊ฐ€์ค‘์ •๊ทœํ™”๋œ IDF๋ฅผ ๊ฐ€์ค‘์น˜๋กœ ๊ณฑ(word_score * idf).sum()
F1์กฐํ™”ํ‰๊ท F = 2PR/(P+R)

IDF ๊ณ„์‚ฐ

def get_idf_dict(arr, tokenizer, nthreads=4):
    idf_count = Counter()
    num_docs = len(arr)
 
    # (1) Tokenize each reference sentence and accumulate DF
    for sent in arr:
        tokens = sent_encode(tokenizer, sent)
        for token_id in set(tokens):     # set → ignore duplicates within a sentence
            idf_count[token_id] += 1     # number of sentences containing the token (DF)
 
    # (2) IDF = log((N+1)/(df+1)) — Laplace smoothing
    idf_dict = defaultdict(
        lambda: log((num_docs + 1) / 1)  # default for unseen tokens (maximum IDF)
    )
    idf_dict.update({
        idx: log((num_docs + 1) / (c + 1))
        for (idx, c) in idf_count.items()
    })
    return idf_dict
  • (1) Iterate over the reference corpus and count each token's document frequency (DF); set(tokens) prevents double-counting within a sentence
  • (2) Laplace smoothing avoids division by zero; unseen tokens get the maximum IDF
  • With idf=False, all tokens are weighted equally

Key Concepts to Clarify

  • ๋ ˆ์ด์–ด ์„ ํƒ: BERT์˜ ๋ชจ๋“  ๋ ˆ์ด์–ด๊ฐ€ ๋™์ผํ•˜๊ฒŒ ์œ ์šฉํ•˜์ง€ ์•Š๋‹ค. ๋…ผ๋ฌธ์—์„œ ์ค‘๊ฐ„ ๋ ˆ์ด์–ด๊ฐ€ semantic similarity์— ์ตœ์ ์ด๋ฉฐ, ์ตœ์ข… ๋ ˆ์ด์–ด๋Š” pretraining objective์— ํŠนํ™”๋˜์–ด ์„ฑ๋Šฅ์ด ๋–จ์–ด์ง์„ ํ™•์ธ. ๊ฐ ๋ชจ๋ธ๋ณ„๋กœ WMT16์„ validation์œผ๋กœ ์ตœ์  ๋ ˆ์ด์–ด๋ฅผ ํƒ์ƒ‰ (Appendix B).
  • Greedy matching vs Optimal matching: BERTScore๋Š” greedy matching (๊ฐ ํ† ํฐ์„ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ƒ๋Œ€ ํ† ํฐ์— ๋งค์นญ) ์„ ํƒ. WMD ๊ธฐ๋ฐ˜ optimal matching(Earth Moverโ€™s Distance)์œผ๋กœ ๊ต์ฒดํ•ด๋„ ์ผ๊ด€๋œ ๊ฐœ์„  ์—†์Œ (Appendix C). MoverScore๋Š” ๊ฐ™์€ ๋งฅ๋ฝ์—์„œ optimal ์„ ํƒ.
  • Baseline rescaling: cosine similarity theoretically ranges over [-1, 1], but in practice the scores fall in a narrow band. An empirical lower bound b is estimated from random sentence pairs drawn from Common Crawl, and scores are rescaled as (x − b)/(1 − b) for readability. Ranking ability is unaffected.
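A minimal sketch of that rescaling, assuming the linear form x̂ = (x − b)/(1 − b) with an empirically measured baseline b (the 0.85 below is a made-up value):

```python
def rescale(score, baseline):
    # Maps the observed range [baseline, 1] onto roughly [0, 1];
    # monotonic, so rankings are unchanged.
    return (score - baseline) / (1 - baseline)

# with a hypothetical baseline b = 0.85 from random sentence pairs:
rescale(0.92, 0.85)  # a raw 0.92 becomes ~0.467
```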

Connections

  • MoverScore — also based on contextual embeddings, but chooses optimal matching (Word Mover's Distance, WMD)
  • BLEU — the canonical n-gram based metric whose limitations BERTScore addresses
  • METEOR — complements exact match with stem/synonym fallbacks; a precursor to BERTScore

Source Trail