HloModule a_inference_grad_loss_49534__XlaMustCompile_true_config_proto___n_007_n_003CPU_020_001_n_007_n_003GPU_020_0012_005__0010J_0008_001_202_001_000__executor_type____.27

// Fused elementwise computation that builds a 5x5 matrix from four length-5 vectors:
//   out[i][j] = 2 * (param_1[i] + param_2[i] - param_0[i]) * param_3[j]
// Called by the ROOT %fusion instruction of the ENTRY computation.
%fused_computation (param_0.5: f32[5], param_1.7: f32[5], param_2.5: f32[5], param_3.3: f32[5]) -> f32[5,5] {
  // Scalar constant 2, broadcast to f32[5] so it can scale the difference vector elementwise.
  %constant_1 = f32[] constant(2), metadata={op_type="Mul" op_name="PartitionedCall_1/gradients/pow_grad/mul_1"}
  %broadcast.4 = f32[5]{0} broadcast(f32[] %constant_1), dimensions={}, metadata={op_type="Mul" op_name="PartitionedCall_1/gradients/pow_grad/mul_1"}
  %param_1.7 = f32[5]{0} parameter(1)
  %param_2.5 = f32[5]{0} parameter(2)
  // Difference vector: (param_1 + param_2) - param_0.
  %add.1 = f32[5]{0} add(f32[5]{0} %param_1.7, f32[5]{0} %param_2.5), metadata={op_type="AddV2" op_name="PartitionedCall/add"}
  %param_0.5 = f32[5]{0} parameter(0)
  %subtract.0 = f32[5]{0} subtract(f32[5]{0} %add.1, f32[5]{0} %param_0.5), metadata={op_type="Sub" op_name="PartitionedCall/sub_0"}
  // 2 * difference; the metadata attributes this multiply to the gradient of a pow op.
  %multiply.3 = f32[5]{0} multiply(f32[5]{0} %broadcast.4, f32[5]{0} %subtract.0), metadata={op_type="Mul" op_name="PartitionedCall_1/gradients/pow_grad/mul_1"}
  // dimensions={0}: the vector indexes output dimension 0, so element [i][j] = multiply.3[i].
  %broadcast.3 = f32[5,5]{1,0} broadcast(f32[5]{0} %multiply.3), dimensions={0}
  %param_3.3 = f32[5]{0} parameter(3)
  // dimensions={1}: the vector indexes output dimension 1, so element [i][j] = param_3[j].
  %broadcast.5 = f32[5,5]{1,0} broadcast(f32[5]{0} %param_3.3), dimensions={1}
  // Elementwise product of the two broadcasts, i.e. the outer product of multiply.3 and param_3.
  ROOT %multiply.2 = f32[5,5]{1,0} multiply(f32[5,5]{1,0} %broadcast.3, f32[5,5]{1,0} %broadcast.5), metadata={op_type="MatMul" op_name="PartitionedCall_1/gradients/MatVec/MatMul_grad/MatMul"}
}

// Binary scalar addition on f32; used as the to_apply reducer of %reduce.1 in %fused_computation.1.
%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_lhs = f32[] parameter(0)
  %scalar_rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %scalar_lhs, f32[] %scalar_rhs)
}

// Fused multiply-and-reduce that yields a length-5 vector:
//   out[i] = sum over j of param_0[i][j] * param_1[j]
// i.e. the product of the 5x5 matrix param_0 with the vector param_1.
// Called by %fusion.1 in the ENTRY computation.
%fused_computation.1 (param_0.4: f32[5,5], param_1.8: f32[5]) -> f32[5] {
  %param_0.4 = f32[5,5]{1,0} parameter(0)
  %param_1.8 = f32[5]{0} parameter(1)
  // dimensions={1}: the vector indexes output dimension 1, so element [i][j] = param_1[j].
  %broadcast.6 = f32[5,5]{1,0} broadcast(f32[5]{0} %param_1.8), dimensions={1}
  %multiply.4 = f32[5,5]{1,0} multiply(f32[5,5]{1,0} %param_0.4, f32[5,5]{1,0} %broadcast.6)
  // Initial value (0) for the sum reduction below.
  %constant_2 = f32[] constant(0)
  // Sums along dimension 1 using %scalar_add_computation, leaving one value per index of dimension 0.
  ROOT %reduce.1 = f32[5]{0} reduce(f32[5,5]{1,0} %multiply.4, f32[] %constant_2), dimensions={1}, to_apply=%scalar_add_computation, metadata={op_type="Squeeze" op_name="PartitionedCall/MatVec/Squeeze"}
}

// Module entry point. Takes a 5x5 matrix (arg0) and three length-5 vectors (arg1, arg2, arg3)
// and returns a 5x5 matrix computed by two fusions:
//   fusion.1[i]  = sum over j of arg0[i][j] * arg1[j]
//   result[i][j] = 2 * (fusion.1[i] + arg2[i] - arg3[i]) * arg1[j]
// The metadata op names suggest this is the gradient of a squared-error loss with respect to
// the matrix argument; NOTE(review): inferred from metadata only, confirm against the originating model.
ENTRY %a_inference_grad_loss_49534__XlaMustCompile_true_config_proto___n_007_n_003CPU_020_001_n_007_n_003GPU_020_0012_005__0010J_0008_001_202_001_000__executor_type____.27 (arg0.1: f32[5,5], arg1.2: f32[5], arg2.3: f32[5], arg3.4: f32[5]) -> f32[5,5] {
  %arg3.4 = f32[5]{0} parameter(3), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %arg0.1 = f32[5,5]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %arg1.2 = f32[5]{0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
  // Matrix-vector product of arg0 and arg1 (see %fused_computation.1).
  %fusion.1 = f32[5]{0} fusion(f32[5,5]{1,0} %arg0.1, f32[5]{0} %arg1.2), kind=kLoop, calls=%fused_computation.1, metadata={op_type="Squeeze" op_name="PartitionedCall/MatVec/Squeeze"}
  %arg2.3 = f32[5]{0} parameter(2), parameter_replication={false}, metadata={op_name="XLA_Args"}
  // Operands bind positionally to %fused_computation's parameters:
  //   param_0 <- arg3.4, param_1 <- fusion.1, param_2 <- arg2.3, param_3 <- arg1.2.
  ROOT %fusion = f32[5,5]{1,0} fusion(f32[5]{0} %arg3.4, f32[5]{0} %fusion.1, f32[5]{0} %arg2.3, f32[5]{0} %arg1.2), kind=kLoop, calls=%fused_computation, metadata={op_type="MatMul" op_name="PartitionedCall_1/gradients/MatVec/MatMul_grad/MatMul"}
}

