import tensorflow as tf
 

class AggLayer(tf.keras.layers.Layer):
    """Multi-aggregator propagation layer with attention-weighted fusion.

    For each aggregator matrix, the (dropout-processed) input features are
    propagated ``num_hops`` times. Each hop mixes the propagated features with
    a term computed from the layer input, gated by the learnable scalars
    ``alpha``/``beta`` and the square matrices ``weight1``/``weight2``. The
    un-propagated input and every per-aggregator result are then combined
    with a softmax attention over heads, scored by two linear kernels.

    Args:
        aggregators: Iterable of rank-2 matrices accepted by
            ``tf.sparse.sparse_dense_matmul`` (typically ``tf.SparseTensor``).
            Each must be ``[N, N]`` where ``N`` is the number of input rows.
            NOTE(review): presumably normalized adjacency matrices; confirm
            with the caller. ``None`` means "no aggregators", in which case
            the layer effectively just applies ``dropout_layer`` to the input.
        dropout_layer: Callable applied to the features before propagation
            and to the stacked heads before fusion. Defaults to identity.
        L2: L2 regularization factor for the attention kernels and biases.
        num_classes: Side length of the square mixing matrices ``weight1``
            and ``weight2``. They right-multiply the feature matrix, so this
            must equal the input's last dimension. ``None`` defers their
            creation to ``build`` and uses that dimension.
        num_hops: Number of propagation steps per aggregator (default 2).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self,
                 aggregators=None,
                 dropout_layer=lambda x: x,
                 L2=0.001,
                 num_classes=None,
                 num_hops=2,
                 **kwargs):
        super(AggLayer, self).__init__(**kwargs)
        # Iterating None in call() would raise TypeError; treat it as "none".
        self.aggregators = [] if aggregators is None else aggregators
        self.dropout_layer = dropout_layer
        self.regularizer = tf.keras.regularizers.l2(L2)
        self.num_classes = num_classes
        self.num_hops = num_hops
        # tf.zeros([None, None]) raises, so when the size is unknown the
        # mixing matrices are created in build() from the input shape.
        if num_classes is not None:
            self._create_mixing_weights(num_classes)
        # Learnable scalar gates shared by every propagation hop.
        self.beta = tf.Variable(initial_value=0.01, dtype=tf.float32, trainable=True)
        self.alpha = tf.Variable(initial_value=0.01, dtype=tf.float32, trainable=True)

    def _create_mixing_weights(self, dim):
        """Create the zero-initialised ``[dim, dim]`` matrices weight1/weight2."""
        self.weight1 = tf.Variable(tf.zeros([dim, dim]))
        self.weight2 = tf.Variable(tf.zeros([dim, dim]))

    def build(self, input_shape):
        """Create the attention kernels/biases (and deferred mixing matrices).

        Args:
            input_shape: Shape of the input features; its last entry is the
                feature dimension ``D``.
        """
        self.in_dim = int(input_shape[-1])

        if self.num_classes is None:
            # The mixing matrices multiply [*, D] features, so they must be D x D.
            self.num_classes = self.in_dim
            self._create_mixing_weights(self.in_dim)

        # Two linear scorers mapping each row's D features to a single score.
        self.gat_kernel_1 = self.add_weight('gat_kernel_1',
                                            shape=[self.in_dim, 1],
                                            regularizer=self.regularizer)

        self.gat_b_1 = self.add_weight('gat_b_1',
                                       initializer='zeros',
                                       shape=[1, 1],
                                       regularizer=self.regularizer)

        self.gat_kernel_2 = self.add_weight('gat_kernel_2',
                                            shape=[self.in_dim, 1],
                                            regularizer=self.regularizer)

        self.gat_b_2 = self.add_weight('gat_b_2',
                                       initializer='zeros',
                                       shape=[1, 1],
                                       regularizer=self.regularizer)

    # Forward pass (plays the role of a Lambda layer's function).
    def call(self, input, training=True):
        """Propagate ``input`` through every aggregator and fuse the results.

        Args:
            input: Dense feature matrix, assumed ``[N, D]``. Rank 2 is
                required by ``tf.sparse.sparse_dense_matmul``.
            training: Not read by this method. It is kept for Keras
                compatibility. NOTE(review): Keras may use its default when
                deciding the training mode of nested layers.

        Returns:
            ``[N, D]`` tensor holding the attention-weighted sum of the input
            head and one propagated head per aggregator.
        """
        # Per-row attention scores, each [N, 1].
        f_1 = tf.matmul(input, self.gat_kernel_1) + self.gat_b_1
        f_2 = tf.matmul(input, self.gat_kernel_2) + self.gat_b_2

        # Head 0 is the un-propagated input, scored with both kernels.
        alpha_list = [f_1 + f_2]
        head_list = [input]

        # Input-dependent term added after every hop. It does not depend on
        # the hop or the aggregator, so it is computed once. It is computed
        # lazily so nothing runs when there are no aggregators.
        initial = None
        for agg in self.aggregators:
            if initial is None:
                initial = ((1 - self.beta) * self.alpha * input
                           + self.beta * tf.matmul(input, self.weight2))

            head = self.dropout_layer(input)
            for _ in range(self.num_hops):
                head = tf.sparse.sparse_dense_matmul(agg, head)
                support = ((1 - self.beta) * (1 - self.alpha) * head
                           + self.beta * tf.matmul(head, self.weight1))
                head = support + initial
            head_list.append(head)

            # Score this head with kernel 2, paired with the input's kernel-1 score.
            alpha_list.append(f_1 + tf.matmul(head, self.gat_kernel_2) + self.gat_b_2)

        # [K+1, N, 1] scores -> softmax across the K+1 heads (axis 0).
        alpha_tensor = tf.stack(alpha_list)
        alpha_tensor = tf.nn.leaky_relu(alpha_tensor)
        alpha_tensor = tf.nn.softmax(alpha_tensor, 0)
        # [K+1, N, D] stacked heads.
        head_tensor = tf.stack(head_list)
        head_tensor = self.dropout_layer(head_tensor)
        # Broadcast [K+1, N, 1] * [K+1, N, D], then sum over heads -> [N, D].
        gat_tensor = tf.multiply(alpha_tensor, head_tensor)
        return tf.reduce_sum(gat_tensor, 0)