トピックモデルで最適なトピック数を決定する方法

トピックモデルで最適なトピック数を決定する方法について調べた。

optimal_k


トピック数を変えながらtopicmodelsを使ってモデルを作成し、データから対数尤度の調和平均を求め、最適なトピック数を求める。
この方法を使ったサンプルがTopic Models Learning and R Resourcesにある。
#Control List
control <- list(burnin = 500, iter = 1000, keep = 100, seed = 2500)
(k <- optimal_k(dpc_dtm, 40, control = control))
リスト1. optimai_k

関数optimal_kが最適なkを求める関数である。ここで、dpc_dtmはDocumentTermMatrixである。これを実行すると最適なトピック数 k とともに、対数尤度の調和平均のトピック数依存性のフラフが表示される。
dpc_dtmのプロフィールを以下に示す。
<<DocumentTermMatrix (documents: 1664, terms: 801)>>
Non-/sparse entries: 167793/1165071
Sparsity           : 87%
Maximal term length: 8
Weighting          : term frequency (tf)
リスト2. DocumentTermMatrixの概要

これはDPC参加病院のみ(1664病院、疾患別手術801件)を対象として作成したDocumentTermMatrixである。
これに対してリスト1を実行したところ、全部で10時間25分かかった。
結果を図1に示す。

図1.optimal_kの実行結果
図から、対数尤度の調和平均は飽和していないことがわかる。もう少しトピック数を増やして調べる必要がある。
しかしながら、optimal_kは、トピック数を2~最大トピック数まで1ずつ増やしながら対数尤度の調和平均を計算しているため、非常に時間がかかる。せめてトピック数を間引きながら増やすことができればよいのだが、関数仕様にそのようなパラメータはない。

そこで、リスト1のcontrolにおけるburninとiterを調整できないか検討してみる。具体的には、LDAをいくつかのトピック数で実行して、対数尤度の推移をプロットし、どのくらいの繰り返し (iter) で飽和するか調べ、最適なburninとiterを決定する。リスト1では、burnin=500、iter=1000になっているが、これを少しでも小さくできれば、optimal_kをもっと高速にできるかもしれない。
これを行うために以下のRスクリプトを作成した。
fitted <- LDA(dpc_dtm, k = 40, method = "Gibbs", control = list(burnin = 500, iter = 1000, keep = 50) )
LL <- data.frame(
  topic_nums = seq(1, 1500, 50),
  logLiks = fitted@logLiks
)
ggplot(LL, aes(x=topic_nums, y=logLiks)) + geom_line()
リスト3. 対数尤度の推移
実行結果を図2に示す。
図2.対数尤度の推移(トピック数=40)
図から分かるように最低でもburnin=250, iter=500は必要で、burnin=500, iter=1000くらいはあった方が良いと考えられる。

したがって、burninとiterの最適化は諦めて、トピック数の間引きを考える。そのために、リスト4のようなRスクリプトを作成した。
topic_nums = seq(5, 100, 5)

burnin = 500
iter = 1000
keep = 100

models <- lapply(topic_nums, function(topic_num){
  model <- LDA(dpc_dtm, k = topic_num, method = "Gibbs", control = list(burnin = burnin, iter = iter, keep = keep) )
  return(model)
})
# モデルを保存
save(models, file = "models-b500-i1000-5-100-5.rda")

# 対数尤度の調和平均のトピック数依存性
logLiks_harmonicMean <- lapply(models, function(model){
  logLiks <- model@logLiks[-c(1:(burnin/keep))]
  return(harmonicMean(logLiks))
})

LL <- data.frame(
  topic_nums = topic_nums,
  logLiks = unlist(logLiks_harmonicMean)
)
ggplot(LL, aes(x=topic_nums, y=logLiks)) + geom_line()
リスト4.対数尤度のトピック数依存性(k=5,10,15,...,100)

リスト4はトピック数を5から始めて5ずつ増やしながら100まで対数尤度の調和平均を計算し、対数尤度の調和平均のトピック数依存性をグラフ化している。結果を図3に示す。
図3.対数尤度の調和平均のトピック数依存性(k=5~100)
図3を見るとトピック数が増えるにつれて対数尤度の調和平均は単調に増加しており、飽和しそうにない。
そこで、今度はトピック数を50から始めて50ずつ増やしながら500まで変化させて対数尤度の調和平均を求めてグラフにした。結果を図4に示す。

図4.対数尤度の調和平均のトピック数依存性(k=50~500)
図からトピック数が350のとき対数尤度の調和平均が最大になっていることがわかる。

図5は繰り返し数(iteration)が進むにつれて対数尤度がどのように変化するかをトピック数ごとに示したものである。

図5.対数尤度の変化
トピック数が大きくなるにつれて対数尤度が飽和するのに要するステップ数が増加する傾向が見られるものの、いずれも収束傾向にある。

Perplexity


これらの結果からトピック数を増やせば増やすほど対数尤度の調和平均は増加することが分かった。原因としてモデルがデータに過適合(オーバーフィッティング)していることが考えられる。

そこで、別の指標で最適なトピック数を求めてみる。ここでは、指標としてPerplexityを考える。ここでは、Topic models: cross validation with loglikelihood or perplexityに従ってPerplexityを計算する。

その前に、Perplexityとは何かについて説明する。
エントロピーとパープレキシティ」によれば、Perplexityとは「単語の平均分岐数を表しており・・・大きいほど,単語の特定が難しく,言語として複雑になる」というものである。1単語あたりのエントロピーをHとした場合、Perplexityは2Hになる。
そこで、さまざまなトピック数kに対してPerplexityを計算し、最小のPerplexityを与えるトピック数 k を求めてみる。
先述したTopic models: cross validation with loglikelihood or perplexity
には"Perplexity is a measure of how well a probability model fits a new set of data."という記述がある。つまり、Perplexityを計算するにはモデルを構築したデータ(すなわち訓練データ)の他に検証データが必要となる。
そこで、DTM形式のDPCデータの75%をランダムに抽出して訓練データとし、残りを検証データとする。

full_data  <- dpc_dtm
n <- nrow(full_data)

splitter <- sample(1:n, round(n * 0.75))
train_set <- full_data[splitter, ]
valid_set <- full_data[-splitter, ]
リスト5.全データを訓練データと検証データに分割(3:1)

次に、訓練データを用いてあるトピック数(ここでは10としている)に対してモデルを構築する。
topic_num = 10

burnin = 500
iter = 1000
keep = 100

model <- LDA(train_set, k = topic_num, method = "Gibbs", control = list(burnin = burnin, iter = iter, keep = keep) )
リスト6.モデルの構築
最後に訓練データと検証データに対してPerplexityを計算する。
PPL_train <- perplexity(model, newdata = train_set)
PPL_valid <- perplexity(model, newdata = valid_set)
リスト7. Perplexityの計算
これをあるトピック数から始めて少しずつ増やしながらPerplexityを計算し、最小を与えるトピック数が最適なトピック数と判断するという考え方である。
リスト8は、これを行うための"Using perplexity and cross-validation to determine a good number of topics"にあるスクリプトである。
install.packages("doParallel")
library(doParallel)
cluster <- makeCluster(detectCores(logical = TRUE) - 1) # leave one CPU spare...
registerDoParallel(cluster)

# load up the needed R package on all the parallel sessions
clusterEvalQ(cluster, {
  library(topicmodels)
})

folds <- 5
splitfolds <- sample(1:folds, n, replace = TRUE)
candidate_k <- c(10, 20, 30, 40, 50, 100, 200, 300, 400, 500) # candidates for how many topics

# export all the needed R objects to the parallel sessions
clusterExport(cluster, c("full_data", "burnin", "iter", "keep", "splitfolds", "folds", "candidate_k"))

# we parallelize by the different number of topics.  A processor is allocated a value
# of k, and does the cross-validation serially.  This is because it is assumed there
# are more candidate values of k than there are cross-validation folds, hence it
# will be more efficient to parallelise
system.time({
  results <- foreach(j = 1:length(candidate_k), .combine = rbind) %dopar%{
    k <- candidate_k[j]
    results_1k <- matrix(0, nrow = folds, ncol = 2)
    colnames(results_1k) <- c("k", "perplexity")
    for(i in 1:folds){
      train_set <- full_data[splitfolds != i , ]
      valid_set <- full_data[splitfolds == i, ]
      
      fitted <- LDA(train_set, k = k, method = "Gibbs",
                    control = list(burnin = burnin, iter = iter, keep = keep) )
      results_1k[i,] <- c(k, perplexity(fitted, newdata = valid_set))
    }
    return(results_1k)
  }
})
stopCluster(cluster)

results_df <- as.data.frame(results)

ggplot(results_df, aes(x = k, y = perplexity)) +
  geom_point() +
  geom_smooth(se = FALSE) +
  ggtitle("5-fold cross-validation of topic modelling with the 'DPC' dataset",
          "(ie five different models fit for each candidate number of topics)") +
  labs(x = "Candidate number of topics", y = "Perplexity when fitting the trained model to the hold-out set")
リスト8.Perplexityから最適なトピック数を決定する

昨日から動かしているけどなかなか終わらない。やっと終わったところで経過時間を見るとなんと所用時間は241541.88(秒)・・・日にち換算で約2.8日。とてもintensiveな計算だ。結果を図6に示す。

図6.Perplexity(リスト8の結果)
トピック数が100あたりまでは急激に減少し、100を超えたあたりからはなだらかに減少し、500に到達しても最小には達せず、引き続き減少傾向にある。
これを図4(対数尤度の調和平均)と比較すると、必ずしも極値は一致しないように思われる。

調べているとこんな記事があった。なんでもldatuningというパッケージがあって、モデルからトピック数を選択するいくつかの指標をはじき出してくれるらしい。その中に「Griffiths2004」というのがあって「要するに、トピック数をT、コーパス全体の単語をwとして、対数尤度logP(w|T)が最大になるTにすればいいじゃん、というアイデアである。」ということらしい。これは、もしかしてリスト4でやっていることなんだろうか?グラフを見ていると似ているなぁと思う。
リスト9にトピック数を50から50ずつ増やしながら500まで変化させたときの各種指標を計算するldatuningのスクリプトを示す。
install.packages("ldatuning")
library("ldatuning")

result <- FindTopicsNumber(
  dpc_dtm,
  topics = seq(from = 50, to = 500, by = 50),
  metrics = c("Griffiths2004", "CaoJuan2009", "Arun2010", "Deveaud2014"),
  method = "Gibbs",
  control = list(seed = 77),
  mc.cores = 2L,
  verbose = TRUE
)

# 可視化
FindTopicsNumber_plot(result)
リスト9.ldatuningによる4つの指標の計算と可視化(k=50~500)

図7にリスト9の実行結果を示す。

図9.4つの指標のトピック数依存性
図9を見ると、Arun2010とCaoJuan2009はトピック数が250~350あたりが最適で、Griffiths2004はトピック数が450でピーク、そしてDeveaud2014はトピック数が100でピークになっている。

LDAの場合、HDP(Hierarchical Dirichlet Process)というのがあって、これを使うとトピック数の自動決定が可能だそうである。しかし、RのHDP実装は無いみたいで、その代わり、GensimがPythonで実装している。加えてRからGensimを呼び出して利用できるようである。

※その後調べているとRにも"R pkg for Hierarchical Dirichlet Process"というのがあるようだ。しかし、"Works on MacOS and Linux, but may not install on Windows."と書かれている。

【補足】

関数optimal_kが何をしているかを知りたいならばソースがGitHubに公開されている。
また、出典となった論文はFindibg Scientific Topicsである。
トピック数の決定については"The input parameters for using latent Dirichlet allocation"や"Topic models: cross validation with loglikelihood or perplexity"で議論されている。

0 件のコメント:

コメントを投稿

ChatGPT は、米国の医師免許試験に太刀打ちできるか?

A Gilson et al.: How Does ChatGPT Perform on the United States Medical Licensing Examination? The Implications of Large Language Models for ...