또르르's 개발 Story

[Stage 3 - 논문 리뷰] CHAN-DST 논문 리뷰 본문

[P Stage 3] DST/논문리뷰

[Stage 3 - 논문 리뷰] CHAN-DST 논문 리뷰

또르르21 2021. 5. 20. 02:04

CHAN-DST

  • slot imbalance 문제를 해결하고자 adaptive objective를 도입.
  • a contextual hierarchical attention network (CHAN)를 사용 : dislogue history에서 relevant context를 찾기 위함. → 각 턴의 발화로부터 word-level 관련 정보 검색 → contextual representation으로 encode → 모든 context표현을 turn-level관련 정보로 집계한 후 word-level 정보와 합친 output 생성.
  • state transition prediction task

Definition

  • T : turn
  • Ut : user utterance of turn t
  • Rt : system response of turn t
  • X : {(U1,R1),...,(UT,RT) }
  • Bt : {(s,vt),sS}
  • S : set of slots vt : corresponding value of the slot s slot : concatenation of a domain name and a slot name

Contextual Hierarchical Attention Network

1. Sentence Encoder

utterance encoder

  • BERT special token사용 → [CLS] : 문장의 representation들을 합치기위해 사용 (to aggregate the whole representation of a sentence) → [SEP] : 문장의 끝을 나타내기위해 사용.
  • Ut = {wu1,...,wul} (user utterance) Rt = {wr1,...,wrl} (system response) t : dialogue turn
  • ht = BERTfinetune([Rt;Ut]) (ht : contextual word representations)
  • 여기서 BERT finetune은 training도중 finetuning이 될것을 의미.

slot-value encoder

  • BERTfixed는 contextual semantics vectors로 encode해준다.
  • utterance encode할때와 다른 점은 [CLS] 토큰의 output vector를 전체 문장 representation할때 사용한다. (to obtain the whole sentence representation)
  • hs = BERTfixed(s) hvt = BERTfixed(vt)
  • BERTfixed는 training 도중 고정되어있다. 그래서 우리 모델은 unseen slots and values에 대해서 original BERT representation로 확장해서 보는게 가능하다.

2. Slot-Word Attention

  • slot-word attention은 multi-head attention을 사용한다.
  • cwords,t = MultiHead(hs,ht,ht)

3. Context Encoder

  • context encoder : unidirectional transformer encoder
  • {1, ..., t} 턴에서 추출 된 word-level slot-related 정보의 contextual relevance를 모델링하기 위한 것.
  • N개의 idenctical한 layer가 있다.
    • 각 layer는 2개의 sub-layer를 가지고 있다.
    • 첫번째 sub-layer : masked multi-head self-attention(Q = K = V)
    • 두번째 sub-layer : position-wise fully connected feed-forward network(FFN) (two linear transformations, RELU activation으로 구성)
    • FFN(x) = max(0, xW1 + b1)W2 + b2
  • mn=FFN(MultiHead(mn1,mn1,mn1)) m0=[cwords,1+PE(1),...,cwords,t+PE(t)] cctxs,t=mN
  • mn : n번째 context encoder레이어의 아웃풋 PE(.) : positional encoding function

4. Slot-Turn Attention

  • turn-level relevant information을 contextual representation에서 검출해내기 위해 사용
  • cturns,t=MultiHead(hs,cctxs,t,cctxs,t)
  • 이로인해 word-level and turn-level 의 relevant information을 historical dialogues에서 얻어낼 수 있다.

5. Global-Local Fusion Gate

  • global context와 local utterance의 균형을 맞추기 위해, contextual information과 current turn information의 비율을 조절함.
  • cwords,t,  cturns,t에 따라 global과 local정보가 어떻게 결합되어야할지 알려주는 fusion gate mechanism을 사용
  • gs,t = σ(Wg[cwords,t;cturns,t])
  • cgs,tate=gs,tcwords,t+(1gs,tcturns,t)
    • WgR2d×d
    • σ : Sigmoid
    • , :
  • os,t = LayerNorm(Linear(Dropout(cgates,t)))
  • value vt에 대한 확률분포와 training objective p(vt|Ut, Rt,s) = exp(||os,thvt||2)vVsexp(||os,thvt||2) Ldst = sSTt=1log(p(ˆvt|Ut, Rt,s))
    • Vs : candidate value set of slot s
    • ˆvtVs : ground-truth value of slot s

State Transition Prediction

  • relevant context를 더 잘 포착하기 위해, auxiliary binary classification task사용.
  • cstps,t = tanh(Wccgates,t)
  • pstps,t = σ(Wp[cstps,t;cstps,t1])
    • Wc\Rd×d
    • Wc\R2d
    • t = 1일때는 cstps,t와 zero vectors를 concat함.
  • binary CE loss (ystps,t : ground-truth transition labels // pstps,t : transition probability)
  • Lstp = sSTt=1ystps,t . log(pstps,t)

Adaptive Objective

  • hard slots와 samples에 관한 optimization을 encourage한다.
  • all slots의 learning을 balancing함.
  • accvals : accuracy of slot s on validation set
  • slot-level difficulty if accvalsaccvals ; → slot s 가 slot s'보다 더 어려운 것. → α : slot-level difficulty
    • αs = 1accvalssS1accvals.|S|
  • sample-level difficulty → Suppose there are two samples {(Ut,Rt),(s,vt)} and {(Ut,Rt),(s,vt)}. → 만약 former confidence 가 latter보다 더 낮다면, 첫번째 sample이 두번째보다 더 어려운 것. → β : sample level difficulty
    • β(s,vt)=(1p(s,vt))γ
    • p(s,vt) : confidence of sample (Ut,Rt),(s,vt) γ : hyper-parameter
  • Ladapt(s,vt)=αsβ(s,vt)log(p(s,vt))
  • slot s가 평균 slot의 difficulty보다 높다면, αs는 s에 대한 loss를 키울 것이다. 비슷하게, sample의 optimization이 low confidence를 갖고 있다면 loss는 커질것이다.

Optimization

  • During joint training, we optimize the sum of these two loss functions as following Ljoint=Ldst+Lstp
  • At the fine-tuning phase, we adopt the adaptive objective to fine-tune DST task as following: Lfinetune = sSTt=1Ladapt(s,ˆvt)
Comments