import os

import networkx as nx
import numpy as np
import torch

from pygcn import load_data

# Extract the largest connected component of each of the 20 edge-list files
# data/<dataset>/<dataset>_<i>.cites and write it, together with the matching
# rows of the feature table, under data/<dataset>_connected/.
dataset = 'pubmed'
features = np.loadtxt('data/{}/{}.content'.format(dataset, dataset))
train_pct = 0.4

# Derive the output location from `dataset` so that inputs and outputs stay in
# sync when the dataset name is changed; create it up front so the first write
# does not fail on a missing directory.
out_name = '{}_connected'.format(dataset)
out_dir = 'data/{}'.format(out_name)
os.makedirs(out_dir, exist_ok=True)

# Column 0 of the feature table holds the node id.  Build the id -> row lookup
# once instead of scanning the whole column for every node of every graph.
# setdefault keeps the FIRST row for a duplicated id, like a first-match scan.
row_of_node = {}
for row, node_id in enumerate(features[:, 0].tolist()):
    row_of_node.setdefault(node_id, row)

# One integer column (the node id) followed by float columns for the rest.
row_fmt = '%d ' + '%f ' * (features.shape[1] - 1)

for i in range(20):
    # NOTE: the graph is read from the per-split edge list; building it from
    # the adjacency returned by pygcn.load_data was an earlier alternative.
    g = nx.read_edgelist('data/{}/{}_{}.cites'.format(dataset, dataset, i), nodetype=int)
    if len(g) == 0:
        raise ValueError('graph {} of dataset {!r} has no nodes'.format(i, dataset))

    # max() returns the first component of maximal size, so ties are broken
    # the same way as a strict ">" comparison while iterating.
    subg = g.subgraph(max(nx.connected_components(g), key=len))
    print(len(g), len(subg))
    nx.write_edgelist(subg, '{}/{}_{}.cites'.format(out_dir, out_name, i), data=False)

    # Gather the feature rows in the node order of the subgraph.  The ids are
    # looked up as floats because the feature table is loaded as floats.
    try:
        rows = [row_of_node[float(v)] for v in subg.nodes()]
    except KeyError as err:
        raise KeyError(
            'node {} of graph {} has no row in the feature table'.format(err.args[0], i)
        ) from None
    features_i = features[np.array(rows, dtype=int)]

    # The same feature rows are written twice: once as the plain content file
    # and once under the train-percentage-tagged name.
    np.savetxt('{}/{}_{}.content'.format(out_dir, out_name, i), features_i, fmt=row_fmt)
    np.savetxt('{}/{}_{}_train_{:.2f}.content'.format(out_dir, out_name, i, train_pct), features_i, fmt=row_fmt)

#edges = torch.tensor(nx.to_numpy_array(subg)).nonzero()
#idx = torch.tensor(list(range(len(subg)))).unsqueeze(1)
#idx_feat_lab = torch.cat((idx.float(), subfeats, sublabels.float()), dim=1)
#idx_feat_lab = idx_feat_lab.numpy()
#
#    
#test_pct = 0.5
#valid_pct = 0.1
#train_pct = 1 - test_pct - valid_pct
#
##edges = np.loadtxt('data/cora/cora.cites', dtype=int)
#
#m = edges.shape[0]
#order = np.random.permutation(list(range(m)))
#edges_train = edges[order[:int(m*train_pct)]]
#edges_valid = edges[order[int(m*train_pct):int(int(m*train_pct) + int(m*valid_pct))]]
#edges_test = edges[order[int(int(m*train_pct) + int(m*valid_pct)):]]
#
#np.savetxt('data/{}_connected/{}_connected.cites'.format(dataset, dataset), edges, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_train_{:.2f}.cites'.format(dataset, dataset, round(train_pct,2)), edges_train, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_valid_{:.2f}.cites'.format(dataset,dataset, round(train_pct, 2)), edges_valid, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_test_{:.2f}.cites'.format(dataset, dataset, round(train_pct, 2)), edges_test, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_train_{:.2f}.content'.format(dataset, dataset, round(train_pct, 2)), idx_feat_lab, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_test_{:.2f}.content'.format(dataset, dataset, round(train_pct, 2)), idx_feat_lab, fmt='%d')
#np.savetxt('data/{}_connected/{}_connected_valid_{:.2f}.content'.format(dataset, dataset, round(train_pct, 2)), idx_feat_lab, fmt='%d')

#edges = np.loadtxt('data/citeseer/citeseer_old.cites', dtype=str)
#edges_num = np.zeros(edges.shape)
#nodes = list(np.unique(edges))
#for i in range(edges.shape[0]):
#    edges_num[i, 0] = nodes.index(edges[i, 0])
#    edges_num[i, 1] = nodes.index(edges[i, 1])
#
#features = np.loadtxt('data/citeseer/citeseer_old.content', dtype=str)
#for i in range(features.shape[0]):
#    features[i, 0] = str(nodes.index(features[i, 0]))
#np.savetxt('data/citeseer/citeseer.content', features, delimiter='\t', fmt = '%s')
#nodes_with_features = [int(x) for x in np.unique(features[:, 0])]
#
#edges_data = np.zeros(edges_num.shape)
#cur_row = 0
#for i in range(edges_num.shape[0]):
#    if edges_num[i, 0] in nodes_with_features and edges_num[i, 1] in nodes_with_features:
#        edges_data[cur_row] = edges_num[i]
#        cur_row += 1
#edges_data = edges_data[:cur_row]
#np.savetxt('data/citeseer/citeseer.cites', edges_data, delimiter='\t', fmt = '%d')
#
#
#idx_features_labels = np.genfromtxt('data/citeseer/citeseer.content', dtype=np.dtype(str))
#idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
#idx_map = {j: i for i, j in enumerate(idx)}
#edges_unordered = np.genfromtxt("data/citeseer/citeseer.cites",
#                                    dtype=np.int32)
#edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
#                     dtype=np.int32).reshape(edges_unordered.shape)


