def self_distillation_loss(labels, logits, model, reg_coef,
                           teacher=None, data=None):
  """Mean-squared-error distillation loss plus an L2 weight penalty.

  The prediction is always `softmax(logits)`. The regression target depends
  on whether a teacher is supplied:

  * `teacher is None`: the target is `labels`, used as given (no softmax is
    applied to it).
  * otherwise: the target is `softmax(teacher(data))`; `labels` is not used
    in this case.

  Args:
    labels: Target values compared against `softmax(logits)` when no teacher
      is given. Presumably one-hot / probability vectors with the same shape
      as `logits` -- confirm against the caller.
    logits: Raw (pre-softmax) outputs of the model being trained.
    model: Object exposing `trainable_weights`; every entry contributes to
      the L2 penalty.
    reg_coef: Multiplier applied to the summed L2 penalty.
    teacher: Optional callable; `teacher(data)` is treated as logits and
      passed through a softmax to form the target.
    data: Input forwarded to `teacher`. Only read when `teacher` is given.

  Returns:
    The scalar `main_loss + reg_loss`.
  """
  # Pick the regression target first so that, when a teacher is present, its
  # forward pass is built before the softmax over the student's logits.
  if teacher is None:
    target = labels
  else:
    # NOTE(review): no stop_gradient is applied to the teacher output here.
    target = tf.nn.softmax(teacher(data))

  student_probs = tf.nn.softmax(logits)
  main_loss = tf.reduce_mean(tf.squared_difference(target, student_probs))

  # tf.nn.l2_loss(w) is sum(w ** 2) / 2 for each weight tensor.
  weight_penalties = [tf.nn.l2_loss(weight)
                      for weight in model.trainable_weights]
  reg_loss = reg_coef * tf.add_n(weight_penalties)

  return main_loss + reg_loss
