#  packages for PReBiM
# install.packages("devtools")
# devtools::install_github("xue-hr/MRCD")
# install.packages("AER")
# install.packages("MASS")
library(AER)
library(MASS)
library(MRCD)

#' Identify valid instrumental variables and estimate causal effects (PReBiM)
#'
#' @param X Exposure variable.
#' @param Y Outcome variable.
#' @param G Candidate instrumental variables (IVs), one column per IV.
#' @param alpha Significance level.
#' @param direction Direction of the relationship between the exposure X and
#'   the outcome Y in the model: 1 for unidirectional, 2 for bidirectional.
#' @param B Size of the smallest candidate IV set that is searched for.
#' @param W Maximum size of a set containing valid IVs.
#' @param Test_method Correlation test passed to \code{cor.test}: one of
#'   "pearson", "spearman" or "kendall".
#'
#' @return A list containing several outcomes of interest:
#'   \item{\code{Valid_IVs_XtoY}}{A valid IV set identified by our method for X to Y.}
#'   \item{\code{Valid_IVs_YtoX}}{A valid IV set identified by our method for Y to X.}
#'   \item{\code{betahat_XtoY}}{An estimated effect calculated using the identified IVs for X to Y.}
#'   \item{\code{betahat_YtoX}}{An estimated effect calculated using the identified IVs for Y to X.}
#'
#' @export
# ----------------- Main function -----------------------
PReBiM <- function(X, Y, G, alpha = 0.01, direction = 2, B = 2, W = 3, Test_method = "pearson") {
  # Searches the columns of G for sets of valid instrumental variables and
  # returns the selected sets together with the 2SLS effect estimates.
  #   X, Y        exposure and outcome (coerced to vectors)
  #   G           candidate IVs (coerced to a matrix, one column per IV)
  #   alpha       significance level of the residual/IV correlation tests
  #   direction   1 = X -> Y only, 2 = bidirectional
  #   B           size of the IV combinations that are scored first
  #   W           size up to which the best combination is extended
  #   Test_method "pearson", "spearman" or "kendall" (passed to cor.test)
  # Returns list(Valid_IVs_XtoY, Valid_IVs_YtoX, betahat_XtoY, betahat_YtoX),
  # or an empty list when no IV set can be formed.

  # ----------------------- obtain g, indices from data ------------------------------
  Y <- as.vector(Y)
  X <- as.vector(X)
  G <- as.matrix(G)
  g <- ncol(G)  # Number of potential instrumental variables (IVs).
  indices <- seq_len(g)  # Column indices of the candidate IVs.
  output <- list()  # Final results.

  # Fail early on an unknown test instead of leaving TestingCorr undefined.
  Test_method <- match.arg(Test_method, c("pearson", "spearman", "kendall"))

  if (!isTRUE(direction %in% c(1, 2))) {
    warning("`direction` must be 1 (unidirectional) or 2 (bidirectional); ",
            "returning an empty list.", call. = FALSE)
    return(output)
  }

  # The data frame handed to ivreg never changes, so build it once.
  iv_data <- data.frame(G, X, Y)

  # ----------------------- function of Testing correlation(Gj, PR) -----------------------
  # Fit Y ~ X by 2SLS with the IVs in PR, then test the correlation between
  # the resulting residual Y - X * beta and the candidate IV Gj.
  TestingCorr <- function(Gj, PR) {
    PR <- unlist(PR)
    iv.reg.xtoy <- ivreg(formula = Y ~ X | G[, PR], data = iv_data)
    beta.2sls <- iv.reg.xtoy$coefficients[2]
    residual <- Y - X * beta.2sls
    result <- cor.test(residual, G[, Gj], method = Test_method)
    list(p.value = result$p.value, estimate = result$estimate)
  }

  # All size-m combinations of the index vector idx, one per column.
  # combn() is applied to positions rather than to idx itself because
  # combn(k, m) with a single number k enumerates seq_len(k), not {k}.
  index_combinations <- function(idx, m) {
    matrix(idx[combn(length(idx), m)], nrow = m)
  }

  # Score every combination (column of AllCom) by the mean absolute
  # correlation obtained when each member is tested against the others.
  score_combinations <- function(AllCom) {
    vapply(seq_len(ncol(AllCom)), function(com) {
      temporaryCom <- AllCom[, com]
      member_corr <- vapply(temporaryCom, function(Gj) {
        abs(TestingCorr(Gj, setdiff(temporaryCom, Gj))$estimate)
      }, numeric(1))
      mean(member_corr)
    }, numeric(1))
  }

  # ----------------------- function of extending MostSig to length W -----------------------
  # Repeatedly add the remaining IV with the smallest absolute correlation,
  # as long as at least one remaining IV is not rejected at level alpha.
  extensive_MostSig <- function(MostSig, indices, alpha = 0.01, W) {
    while ((length(MostSig) < W) && (length(indices) > length(MostSig))) {
      updatedMostSig <- unlist(MostSig)
      update_indices <- setdiff(indices, updatedMostSig)

      # One 2SLS fit per candidate supplies both the p-value and the estimate.
      tests <- lapply(update_indices, TestingCorr, PR = updatedMostSig)
      p_values <- vapply(tests, function(res) res$p.value, numeric(1))
      if (!any(p_values > alpha)) {
        break
      }
      abs_corr <- vapply(tests, function(res) abs(res$estimate), numeric(1))
      MostSig <- append(updatedMostSig, update_indices[which.min(abs_corr)])
    }
    MostSig
  }

  # ----------------------- Direction 1: unidirectional -----------------------
  if (direction == 1) {

    # Filtering: keep only IVs that are associated with X.
    for (Gi in indices) {
      p.value <- cor.test(X, G[, Gi])$p.value
      if (p.value > 0.001) {
        indices <- setdiff(indices, Gi)
      }
    }

    # Nothing is searched when the starting size already exceeds W.
    if (B > W) {
      return(output)
    }
    if (length(indices) < B) {
      stop("Only ", length(indices), " candidate IV(s) passed the relevance ",
           "filter, but B = ", B, " are needed.", call. = FALSE)
    }

    # Best-scoring combination of size B.
    AllCom <- index_combinations(indices, B)
    MostSig <- AllCom[, which.min(score_combinations(AllCom))]

    # Extend MostSig to length W when there are IVs left to add.
    if ((length(MostSig) < W) && (length(indices) > B)) {
      MostSig <- extensive_MostSig(MostSig, indices, alpha = alpha, W)
    }

    iv.reg.xtoy <- ivreg(formula = Y ~ X | G[, MostSig], data = iv_data)
    beta.2sls.xtoy <- iv.reg.xtoy$coefficients[2]

    output <- list(Valid_IVs_XtoY = MostSig, Valid_IVs_YtoX = NULL,
                   betahat_XtoY = beta.2sls.xtoy, betahat_YtoX = 0)
    return(output)
  }

  # ----------------------- Direction 2: bi-directional -----------------------
  if (direction == 2) {

    # Collect up to two IV sets; every pass removes at least B indices.
    while ((length(output) < 2) && (length(indices) > B)) {

      AllCom <- index_combinations(indices, B)
      MostSig <- AllCom[, which.min(score_combinations(AllCom))]

      # Extend MostSig to length W.
      Big_MostSig <- extensive_MostSig(MostSig, indices, alpha = alpha, W)

      if (length(output) < 1) {
        indices <- setdiff(indices, unlist(Big_MostSig))
        output <- append(output, list(Big_MostSig))
      } else {
        # Check whether the new set can be merged with the stored one:
        # it cannot if any member is rejected against the rest of the union.
        unionset <- unique(c(unlist(output), unlist(Big_MostSig)))
        cannot_merge <- FALSE
        for (Gj in unionset) {
          p.value <- TestingCorr(Gj, setdiff(unionset, Gj))$p.value
          if (p.value <= alpha) {
            cannot_merge <- TRUE
            break
          }
        }

        if (cannot_merge) {
          # Not mergeable: keep it as the second set.
          Valid_IVs <- unlist(Big_MostSig)
          indices <- setdiff(indices, Valid_IVs)
          output <- append(output, list(Valid_IVs))
        } else {
          # Mergeable: discard and look for the IVs of the other direction.
          indices <- setdiff(indices, unionset)
        }
      }
    }

    if (length(output) == 0) {
      warning("More than B = ", B, " candidate IVs are required; ",
              "returning an empty list.", call. = FALSE)
    }

    # Only one set found: fall back to the two best-scoring combinations
    # of size B over all candidate IVs.
    if (length(output) == 1) {
      AllCom <- index_combinations(seq_len(g), B)
      best_two <- order(score_combinations(AllCom))[1:2]
      output <- list(AllCom[, best_two[1]], AllCom[, best_two[2]])
    }

    # Differentiate directions: the set with the smaller |K| (first element
    # returned by calculating_K) is used for X -> Y.
    if (length(output) > 1) {
      K_set0 <- calculating_K(G[, output[[1]]], X, Y, V_T1 = NULL, V_T2 = NULL)
      K_set1 <- calculating_K(G[, output[[2]]], X, Y, V_T1 = NULL, V_T2 = NULL)
      kyx0 <- abs(K_set0[[1]])
      kyx1 <- abs(K_set1[[1]])
      if (kyx0 > kyx1) {
        G_XtoY <- output[[2]]
        G_YtoX <- output[[1]]
      } else {
        G_XtoY <- output[[1]]
        G_YtoX <- output[[2]]
      }

      iv.reg.xtoy <- ivreg(formula = Y ~ X | G[, G_XtoY], data = iv_data)
      beta.2sls.xtoy <- iv.reg.xtoy$coefficients[2]
      iv.reg.ytox <- ivreg(formula = X ~ Y | G[, G_YtoX], data = iv_data)
      beta.2sls.ytox <- iv.reg.ytox$coefficients[2]

      output <- list(Valid_IVs_XtoY = G_XtoY, Valid_IVs_YtoX = G_YtoX,
                     betahat_XtoY = beta.2sls.xtoy, betahat_YtoX = beta.2sls.ytox)
    }
  }

  return(output)
}

#---------------------- P4 required -------------------------------------------

#' Calculate Asymptotic Variance Matrix
#'
#' Computes the asymptotic covariance matrix of the sample correlations
#' between each SNP and X, following Theorem 2 of
#' "The Asymptotic Variance Matrix of the Sample Correlation Matrix".
#'
#' @param SNP genotype matrix with one row per individual and one column per
#' SNP; it is standardized internally.
#' @param rho square sample correlation matrix of the SNPs and X, with X in
#' the last row and column.
#'
#' @return the block of the full asymptotic covariance matrix that belongs to
#' the SNP-with-X correlations (one row and column per SNP).
#' @export
calculate_asymptotic_variance <- function(SNP, rho) {
  n_snp <- ncol(rho) - 1
  dim_full <- n_snp + 1

  # Standardize the columns, then rescale by sqrt(N) / sqrt(N - 1).
  SNP_std <- scale(SNP)
  n_obs <- nrow(SNP_std)
  SNP_std <- SNP_std * sqrt(n_obs) / sqrt(n_obs - 1)

  # M_1 = I - M_s (I (x) rho) M_d
  sym_part <- eigenMapMatMult(
    generate_M_s(dim_full),
    kronecker(diag(dim_full), rho)
  )
  M_1 <- diag(dim_full^2) - eigenMapMatMult(sym_part, generate_M_d(dim_full))

  V <- generate_V(SNP_std, rho)

  # Sandwich M_1 V M_1'
  asymp_cov <- eigenMapMatMult(eigenMapMatMult(M_1, V), t(M_1))

  # Positions of the SNP-with-X entries (last column of rho, without its
  # diagonal element) in the column-stacked correlation matrix.
  keep <- (n_snp * dim_full + 1):(dim_full^2 - 1)
  return(asymp_cov[keep, keep])
}

#' Generate matrix M_s
#'
#' Builds the matrix M_s of formula 2.9 in
#' "The Asymptotic Variance Matrix of the Sample Correlation Matrix".
#'
#' @param n dimension.
#'
#' @return M_s an n^2 by n^2 matrix equal to half the sum of the identity and
#' the matrix K that has a one at row (i-1)*n + j, column (j-1)*n + i.
#' @export
generate_M_s <- function(n) {
  idx <- 1:n
  # Row and column positions of the ones in K, for every pair (i, j).
  rows <- as.vector(outer(idx, idx, function(j, i) (i - 1) * n + j))
  cols <- as.vector(outer(idx, idx, function(j, i) (j - 1) * n + i))

  K <- matrix(0, n^2, n^2)
  K[cbind(rows, cols)] <- 1

  (diag(n^2) + K) / 2
}

#' Generate matrix M_d
#'
#' Builds the matrix M_d of formula 2.13 in
#' "The Asymptotic Variance Matrix of the Sample Correlation Matrix".
#'
#' @param n dimension.
#'
#' @return M_d an n^2 by n^2 matrix of zeros with a one at each diagonal
#' position (i-1)*n + i, for i = 1, ..., n.
#' @export
generate_M_d <- function(n) {
  i <- 1:n
  diag_pos <- (i - 1) * n + i

  M_d <- matrix(0, n^2, n^2)
  M_d[cbind(diag_pos, diag_pos)] <- 1
  M_d
}

#' Generate matrix V
#'
#' Builds the matrix V of formula 3.6 in
#' "The Asymptotic Variance Matrix of the Sample Correlation Matrix".
#' A variable X is simulated (with \code{rnorm}) from the SNPs so that the
#' result depends on the state of the random number generator.
#'
#' @param SNP genotype matrix with one row per individual and one column per
#' SNP, standardized with column means 0 and variances 1.
#' @param rho square sample correlation matrix of the SNPs and X, with X in
#' the last row and column.
#'
#' @return V V matrix.
#' @export
generate_V <- function(SNP, rho) {
  p <- ncol(SNP)

  # Regression of X on the SNPs implied by rho: coefficients and the
  # remaining (unexplained) variance.
  Sigma_inv <- solve(rho[1:p, 1:p], tol = 0)
  rho_X <- as.matrix(rho[1:p, (p + 1)])
  coef_X <- Sigma_inv %*% rho_X
  resid_var <- 1 - t(rho_X) %*% Sigma_inv %*% rho_X

  # diag() on a single number would build an identity matrix, so the
  # one-SNP case is a plain scalar multiplication.
  weighted <- if (nrow(coef_X) > 1) {
    SNP %*% diag(as.numeric(coef_X))
  } else {
    SNP * as.numeric(coef_X)
  }
  S <- rowSums(weighted)

  # Simulate X and standardize it the same way as the SNPs.
  n_obs <- length(S)
  sim_X <- S + rnorm(n_obs, 0, sqrt(resid_var))
  sim_X <- scale(sim_X) * sqrt(n_obs) / sqrt(n_obs - 1)
  M_SX <- cbind(SNP, sim_X)

  # Column i of M_SX multiplied into every column, blocks placed side by side.
  BigM <- do.call(
    cbind,
    lapply(1:(p + 1), function(i) M_SX[, i] * M_SX)
  )

  # The original author's note gives t(BigM) %*% BigM / n_obs as the
  # plain-R form of this product.
  V <- eigenMapMatMult(t(BigM), BigM) / n_obs

  # Subtract the outer product of the column-stacked correlation matrix.
  vSigma <- as.matrix(c(rho))
  return(V - vSigma %*% t(vSigma))
}
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

# Thin wrapper that passes A and B unchanged to the compiled routine
# '_MRCD_armaMatMult' of the MRCD package via .Call and returns its result.
# NOTE(review): presumably the matrix product of A and B computed with
# Armadillo -- confirm against the MRCD C++ source.
armaMatMult <- function(A, B) {
  .Call('_MRCD_armaMatMult', PACKAGE = 'MRCD', A, B)
}

# Thin wrapper that passes A and B unchanged to the compiled routine
# '_MRCD_eigenMatMult' of the MRCD package via .Call and returns its result.
# NOTE(review): presumably the matrix product of A and B computed with
# Eigen -- confirm against the MRCD C++ source.
eigenMatMult <- function(A, B) {
  .Call('_MRCD_eigenMatMult', PACKAGE = 'MRCD', A, B)
}

# Thin wrapper that passes A and B unchanged to the compiled routine
# '_MRCD_eigenMapMatMult' of the MRCD package via .Call and returns its
# result. This is the multiplication routine used by the variance helpers in
# this file.
# NOTE(review): presumably the matrix product of A and B computed with
# Eigen on mapped (non-copied) inputs -- confirm against the MRCD C++ source.
eigenMapMatMult <- function(A, B) {
  .Call('_MRCD_eigenMapMatMult', PACKAGE = 'MRCD', A, B)
}

# Inverse-variance (GLS) weighted ratio estimates built from the sample
# correlations of a set of IVs with X and with Y.
#
#   G_valid  genotype data of the selected IVs, one column per IV (a plain
#            vector is accepted for a single IV)
#   X, Y     exposure and outcome vectors
#   V_T1     optional asymptotic covariance matrix of the IV-with-X
#            correlations; computed with calculate_asymptotic_variance()
#            when NULL
#   V_T2     same for the IV-with-Y correlations
#
# Returns an unnamed list of two numbers: the weighted estimate of
# cor(IV, Y) / cor(IV, X) and the weighted estimate of
# cor(IV, X) / cor(IV, Y).
calculating_K <- function(G_valid, X, Y, V_T1 = NULL, V_T2 = NULL)
{
  # A single column selected with G[, j] arrives as a vector; restore the
  # matrix shape so that ncol()/nrow() are defined.
  SNP <- as.matrix(G_valid)
  g <- ncol(SNP)
  n_obs <- nrow(SNP)

  T1_r <- as.vector(cor(SNP, X))
  T2_r <- as.vector(cor(SNP, Y))
  snp_cor <- cor(SNP)

  # Correlation matrix of the IVs and one trait, trait in the last row/column.
  build_rho <- function(trait_r) {
    rho <- matrix(0, ncol = (g + 1), nrow = (g + 1))
    rho[1:g, 1:g] <- snp_cor
    rho[1:g, g + 1] <- trait_r
    rho[g + 1, 1:g] <- trait_r
    rho[g + 1, g + 1] <- 1
    rho
  }

  # X first, then Y: the variance routine draws random numbers, so the call
  # order is kept fixed.
  if (is.null(V_T1)) {
    V_T1 <- calculate_asymptotic_variance(SNP, build_rho(T1_r))
  }
  if (is.null(V_T2)) {
    V_T2 <- calculate_asymptotic_variance(SNP, build_rho(T2_r))
  }

  zeros <- matrix(0, ncol = g, nrow = g)

  # Delta-method covariance of the ratios num / den, followed by the
  # inverse-variance weighted mean of those ratios.
  weighted_ratio <- function(num, den, V_num, V_den) {
    # nrow = g keeps diag() from building an identity matrix when g == 1.
    jacobian <- cbind(diag(1 / den, nrow = g),
                      -diag(num / den^2, nrow = g))
    combined_V <- rbind(cbind(V_num, zeros),
                        cbind(zeros, V_den)) / n_obs
    V <- jacobian %*% combined_V %*% t(jacobian)
    inv_V <- solve(V, tol = 0)
    sum(inv_V %*% (num / den)) / sum(inv_V)
  }

  gls_est1 <- weighted_ratio(T2_r, T1_r, V_T2, V_T1)
  gls_est2 <- weighted_ratio(T1_r, T2_r, V_T1, V_T2)

  return(list(gls_est1, gls_est2))
}

