从sksurv 库中的concordance_index_censored函数说起
sksurv库的安装和简单使用参见生存分析序章3——生存分析之Python篇:介绍生存分析和scikit-survival库-CSDN博客
需要注意的是sksurv库有时候对于python版本的要求比较严格
进入正题,这个函数是用于计算c-index,即一致性指数,用于评价模型的预测能力,所有病人对子中预测结果与实际结果一致的对子所占的比例。它估计了预测结果与实际观察到的结果相一致的概率。
c指数的计算方法是:把所研究的资料中的所有研究对象随机地两两组成对子。以生存分析为例,对于一对病人,如果生存时间较长的一位的预测生存时间也长于另一位的预测生存时间,或预测的生存概率高的一位的生存时间长于生存概率低的另一位,则称之为预测结果与实际结果一致。
计算步骤如下:
1、产生所有的病例配对。若有n个观察个体,则所有的对子数应为$C_n^2$(组合数)
2、保留“有效对”:(1)都观测到死亡事件发生且二者时间不同(2)观测到死亡时间的个体生存时间较短。
3、计算有用对子中,预测结果和实际相一致的对子数, 即具有较坏预测结果个体的实际观察时间较短。
4、计算,C-index = 一致对子数/有用对子数。
在sksurv库中,这个函数的输入包含三个核心参数,在MCAT的代码中分别对应
- event indicator——$(1-censorships).astype(bool)$
- event time——$all$ $event$ $times$
- estimate——$all$ $risk$ $scores$
这里贴一段MCAT中的代码,经典永流传
for batch_idx, (data_WSI, data_omic, label, event_time, c) in enumerate(loader):
data_WSI, data_omic = data_WSI.to(device), data_omic.to(device)
label = label.to(device)
c = c.to(device)
hazards, S, Y_hat, _, _ = model(x_path=data_WSI, x_omic=data_omic) # return hazards, S, Y_hat, A_raw, results_dict
loss = loss_fn(hazards=hazards, S=S, Y=label, c=c)
loss_value = loss.item()
if reg_fn is None:
loss_reg = 0
else:
loss_reg = reg_fn(model) * lambda_reg
risk = -torch.sum(S, dim=1).detach().cpu().numpy()
all_risk_scores[batch_idx] = risk
all_censorships[batch_idx] = c.item()
all_event_times[batch_idx] = event_time
train_loss_surv += loss_value
train_loss += loss_value + loss_reg
if (batch_idx + 1) % 100 == 0:
print('batch {}, loss: {:.4f}, label: {}, event_time: {:.4f}, risk: {:.4f}, bag_size: {}'.format(batch_idx, loss_value + loss_reg, label.item(), float(event_time), float(risk), data_WSI.size(0)))
# backward pass
loss = loss / gc + loss_reg
loss.backward()
if (batch_idx + 1) % gc == 0:
optimizer.step()
optimizer.zero_grad()
# calculate loss and error for epoch
train_loss_surv /= len(loader)
train_loss /= len(loader)
# c_index = concordance_index(all_event_times, all_risk_scores, event_observed=1-all_censorships)
c_index = concordance_index_censored((1-all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]
这段代码中,输入到c-index函数的三个参数分别是1-censorship,event time和risk score
- event time很好理解,就是记录的死亡发生时间,在没有记录到死亡时,则以最后的检查时间代替
- 而censorship则是一个较复杂的概念,在MCAT的补充材料中,$c \in \{0,1\}$表示“right uncensorship status (death observed)”,也即在这篇论文中,c=0表示病人已死亡,c=1表示病人在最后观察期内仍旧存活。在计算c-index时,输入值为$1-c$,即删失状态,表示在研究期间观测时间没有发生(没死)。
- risk score是需要从模型输出经过一系列计算得到的,这里给出代码及注释
logits = model(x)#这里是一个[1x4]的vector,表示死亡时间落在各个区间的概率
hazards = torch.sigmoid(logits)#对logits做归一化,使得每个数值压缩到01之间
S = torch.cumprod(1 - hazards, dim=1)#torch.cumprod函数按维度计算前缀积,这里用1去减是计算生存概率,则S表示从开始到每个区间没有发生死亡的概率
risk = np.asscalar(-torch.sum(S, dim=1).cpu().numpy())#求和是计算所有区间内的生存概率总和,取负数则得到风险评分
all_risk_scores.append(risk)#得到最终可用于计算的risk分数
生存预测中的研究定义
参考blog.51cto.com/16099346/9101912
删失(censored)
删失即在观测期间没有预期事件的发生,在生存预测中就是在观察时间内病人未死亡
发生删失的原因:
- 病人还活着(right censored)
- 发生其他事件导致无法继续观察(死于其他疾病)
生存概率、风险概率
- 生存概率$S(t)=P(T \geq t)$,其中T为病人死亡时间变量
- 风险概率(hazard probability)$h(t)=\lim_{\delta(t) \rightarrow 0} \frac{P(t \leq T \leq t + \delta(t)|T \geq t)}{\delta(t)}$
- 累计风险(cummulative hazard)$H(t)=-\log(S(t))$
推导如下:
定义$F(t)=1-S(t)=P(T\leq t)$ ,即风险累计分布函数,则对应的概率密度$f(t)=\frac{dF(t)}{dt}$
则$h(t)=\lim_{\delta(t) \rightarrow 0} \frac{P(t \leq T \leq t + \delta(t)|T \geq t)}{\delta(t)}=\lim_{\delta(t) \rightarrow 0} \frac{F(t+\delta(t)) – F(t)}{\delta(t)S(t)}=\frac{f(t)}{S(t)}$
于是有$\frac{dH(t)}{dt}=-\frac{1}{S(t)}\cdot (dS(t))=-\frac{1}{S(t)}\cdot (-f(t))=\frac{f(t)}{S(t)}=h(t)$
Kaplan-Meier Survival estimate
这是一种非参的生存概率估计方法,数据点为观测到的生存时间
假设有各个离散时间点的病人数据,定义生存概率$S(t_n)=S(t_{n-1})(1-\frac{d_n}{r_n})$
其中$d_n$指时间$t_n$发生的事件(死的人数),$r_n$指快到$t_n$时还存活的人数(剔除删失人数)
Cox比例风险回归模型
参考最直观的理解Cox模型-生存分析Survival Analysis-Chapter 3-Cox模型及其特点(a) – 知乎
该模型为多变量的生存模型,可以解决KM法无法分析连续变量的问题
定义:
- 基线风险方程$h_0(t)$为一个对时间的非负方程,与特征向量无关
- $X_i$为实例$i$的特征向量,大部分时候与时间无关,及time independent变量
- $\beta$为参数向量
则可以定义风险概率$h(t,X_i)=h_0(t) \times e^{X_i \beta}$
这里值得注意的是,风险概率函数在所有的X都为0时,这个基线风险方程仍是一个未确定的函数,因此Cox模型试一个半参数模型
要说半参数回归,就要明白参数回归和非参数回归。**参数回归是事先假定模型的形式,然后用数据去估计这个模型的系数。而非参数回归则是不假定模型形式,直接从数据来拟合模型。**参数回归最基本的是线性模型,非参数回归最简单的最近邻方法。而半参数回归则是,模型中有一部分的结构是已知的,需要估计参数,而另外一部分结构未知。半参数回归种类非常多,除了cox回归外,MARS,gam应该都是半参数回归。
现在对Cox模型做极大似然估计,得到$\hat\beta_{i}$ ,则对于每一个实例$i$,定义似然函数:
$$
L_i(\beta) = \frac{h(T_i, X_i)}{\sum_{j: T_j \geq T_i} h(T_i, X_j)}
$$
由于要在每一个实例的对应时间处计算当前的风险集(即还没死的人),Cox模型的似然是基于观察到的事件的发生顺序(这个意思是与具体的时间值并无关系,只考虑先后顺序)
对每个实例的似然函数,由于分子分母均有$h_0(t)$,可以直接约分消去,进一步,完整的似然函数就是:$L=\prod_{i=1}^{k}L_i=\prod_{i=1}^{k} \frac{e^{X_i\beta}}{\sum_{j: T_j \geq T_i} e^{X_j\beta}}$
求解就是对$\beta$求导
实验细节
数据集处理
在实际实验中,生存预测常用TCGA官网数据集,但是官网对于censorship和event time的标注比较乱,因此建议从三方网站获取数据label文件,配合从TCGA官网下载的SVS数据,这里提供一种从linkedomics网站下载的方法:
在图中右侧Data Download下载,下载得到的是tsi格式,建议改成tsv方便读取
以下代码用于数据预处理
# 已经下载了BLCA.tsv
# prepare clinical data
clinical = pd.read_csv(clinical_file, sep='\t')
# 剔除时间为空的样本
clinical = clinical.T
clinical.columns = clinical.iloc[0]
clinical = clinical.drop(clinical.index[0], 0)
clinical = clinical[clinical['overall_survival'].notnull()]
# divide dat into 4 bins 参考MCAT
label_col = 'overall_survival'
assert label_col in clinical.columns
clinical[label_col] = clinical[label_col].astype(float)
clinical['status'] = clinical['status'].astype(int)
patients_df = clinical.copy()
print(patients_df[label_col].min())
print(patients_df[label_col].max())
print(patients_df['status'].values)
# 取死亡病例
uncensored_df = patients_df[patients_df['status'] == 1]
eps=1e-6
n_bins = 4
disc_labels, q_bins = pd.qcut(uncensored_df[label_col], q=n_bins, retbins=True, labels=False, duplicates='drop')
# 更新bins上下界
q_bins[-1] = patients_df[label_col].max() + eps
q_bins[0] = patients_df[label_col].min() - eps
print(q_bins)
# 重新划分bins
disc_labels, q_bins = pd.cut(patients_df[label_col], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True)
print(disc_labels)
patients_df.insert(2, 'label', disc_labels.values.astype(int)) print(patients_df)
数据集处理后对照TCGA官网图片数据集即可,MCAT是对标签文件做了去重,对每个病例只取一张病理图,个人理解每个病例取所有图片也可以,因为在计算C-index的时候同一个病例的不同图片并不会被考虑进去
MCAT中的魔改损失函数
背景:将数据集中的survival time分割成四个区间$[t_0,t_1), [t_1,t_2),[t_2,t_3),[t_3,t_4)$,并定义每个病人的标签:$$T_j=r \ \ \ if \ \ T_{j,continous} \in [t_r,t_{r+1})\ for \ r \in {0,1,2,3}$$在给定输入:第j个病人的GT,即$Y_j$;bag-level的特征$h_{final_j}$;以及如上文定义的$f_{hazard}$和$f_{surv}$
则训练中的生存损失定义如下:
$$
L_{surv}=(1-\beta)L + \beta L_{uncensored}
$$
其中:
$$
\begin{align}
L=-l=&-c_j\cdot \log(f_{surv}(Y_j|h_{final_j})) \\
&- (1-c_j)\cdot \log(f_{surv}(Y_j – 1|h_{final_j})) \\
&- (1-c_j)\cdot \log(f_{hazard}(Y_j|h_{final_j}))
\end{align}
$$
$$
\begin{align}
L_{uncensored} = &-(1-c_j) \cdot \log(f_{surv}(Y_j – 1 | h_{final_j})) \\
&- (1-c_j) \cdot \log(f_{hazard}(Y_j | h_{final_j}))
\end{align}
$$
这里面的$L_{uncensored}$用于增加c为0的数据(病人观察到死亡)的贡献
在$L$中则是包含两部分:
- 对于没死(c=1)的病人,计算预测的存活时间在$Y_j$之后的概率
- 对于已死(c=0)的病人,计算预测的存活时间在$Y_j-1$和$Y_j$之间的概率
而所谓的增加死亡病人的这个Loss,就是把这部分病人的Loss值从$1-\beta$的系数变成了$1$的系数
这里放一下魔改过后的两种Loss,NLL损失和CE损失
def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
batch_size = len(Y)
Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
c = c.view(batch_size, 1).float() #censorship status, 0 or 1
if S is None:
S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards
# without padding, S(0) = S[0], h(0) = h[0]
S_padded = torch.cat([torch.ones_like(c), S], 1) #S(-1) = 0, all patients are alive from (-inf, 0) by definition
# after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]
#h[y] = h(1)
#S[1] = S(1)
uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
neg_l = censored_loss + uncensored_loss
loss = (1-alpha) * neg_l + alpha * uncensored_loss
loss = loss.mean()
return loss
def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
batch_size = len(Y)
Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
c = c.view(batch_size, 1).float() #censorship status, 0 or 1
if S is None:
S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards
# without padding, S(0) = S[0], h(0) = h[0]
# after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]
#h[y] = h(1)
#S[1] = S(1)
S_padded = torch.cat([torch.ones_like(c), S], 1)
reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
ce_l = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (1 - c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
loss = (1-alpha) * ce_l + alpha * reg
loss = loss.mean()
return loss