HloModule jit_loss__540.66

// Binary reducer on f32 scalars: returns scalar_lhs + scalar_rhs.
// Used as the `to_apply` computation of %reduce.4 in %fused_computation.
%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_lhs = f32[] parameter(0)
  %scalar_rhs = f32[] parameter(1)
  ROOT %add.2 = f32[] add(f32[] %scalar_lhs, f32[] %scalar_rhs)
}

// Body of the %fusion instruction in the ENTRY computation.
//
// Writing A = param_3.5 (f32[5,5]), v = param_2.6, b = param_1.10 and
// t = param_0.5 (each f32[5]), the ROOT value is
//
//   out[i, j] = 2 * (sum_k A[i, k] * v[k] + b[i] - t[i]) * v[j]
//
// i.e. the outer product of 2 * (A.v + b - t) with v.
// NOTE(review): the metadata op_names (jvp / transpose(jvp(loss))) suggest this
// is the gradient of a squared-error loss with respect to A -- confirm against
// the originating source; the loss definition itself is not visible here.
%fused_computation (param_0.5: f32[5], param_1.10: f32[5], param_2.6: f32[5], param_3.5: f32[5,5]) -> f32[5,5] {
  // Constant factor 2, broadcast to a length-5 vector.
  %constant_16 = f32[] constant(2), metadata={op_type="mul" op_name="jit(loss)/jit(jvp(loss))/mul" source_file="<ipython-input-73-db04383b747a>" source_line=8}
  %broadcast.6 = f32[5]{0} broadcast(f32[] %constant_16), dimensions={}, metadata={op_type="mul" op_name="jit(loss)/jit(jvp(loss))/mul" source_file="<ipython-input-73-db04383b747a>" source_line=8}
  %param_3.5 = f32[5,5]{1,0} parameter(3)
  %param_2.6 = f32[5]{0} parameter(2)
  // dimensions={1}: the vector is laid along output dimension 1, so
  // broadcast.10[i, j] = v[j]. This value is reused by the ROOT multiply below.
  %broadcast.10 = f32[5,5]{1,0} broadcast(f32[5]{0} %param_2.6), dimensions={1}
  // Elementwise product followed by a sum over dimension 1 (initial value 0):
  // reduce.4[i] = sum_k A[i, k] * v[k], a matrix-vector product written as
  // multiply + reduce.
  %multiply.9 = f32[5,5]{1,0} multiply(f32[5,5]{1,0} %param_3.5, f32[5,5]{1,0} %broadcast.10)
  %constant_20 = f32[] constant(0)
  %reduce.4 = f32[5]{0} reduce(f32[5,5]{1,0} %multiply.9, f32[] %constant_20), dimensions={1}, to_apply=%scalar_add_computation, metadata={op_type="dot_general" op_name="jit(loss)/jit(jvp(loss))/jit(jvp(fn1))/dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n                                                    precision=None ]" source_file="<ipython-input-73-db04383b747a>" source_line=3}
  %param_1.10 = f32[5]{0} parameter(1)
  // add.3 = A.v + b
  %add.3 = f32[5]{0} add(f32[5]{0} %reduce.4, f32[5]{0} %param_1.10), metadata={op_type="add" op_name="jit(loss)/jit(jvp(loss))/jit(jvp(fn1))/add" source_file="<ipython-input-73-db04383b747a>" source_line=3}
  %param_0.5 = f32[5]{0} parameter(0)
  // subtract.1 = (A.v + b) - t
  %subtract.1 = f32[5]{0} subtract(f32[5]{0} %add.3, f32[5]{0} %param_0.5), metadata={op_type="sub" op_name="jit(loss)/jit(jvp(loss))/sub" source_file="<ipython-input-73-db04383b747a>" source_line=8}
  // multiply.6 = 2 * (A.v + b - t)
  %multiply.6 = f32[5]{0} multiply(f32[5]{0} %broadcast.6, f32[5]{0} %subtract.1), metadata={op_type="mul" op_name="jit(loss)/jit(jvp(loss))/mul" source_file="<ipython-input-73-db04383b747a>" source_line=8}
  // dimensions={0}: the vector is laid along output dimension 0, so
  // broadcast.5[i, j] = multiply.6[i].
  %broadcast.5 = f32[5,5]{1,0} broadcast(f32[5]{0} %multiply.6), dimensions={0}
  // Outer product: out[i, j] = multiply.6[i] * v[j].
  ROOT %multiply.5 = f32[5,5]{1,0} multiply(f32[5,5]{1,0} %broadcast.5, f32[5,5]{1,0} %broadcast.10), metadata={op_type="dot_general" op_name="jit(loss)/jit(transpose(jvp(loss)))/jit(transpose(jvp(fn1)))/dot_general[ dimension_numbers=(((), ()), ((), ()))\n                                                                          precision=None ]" source_file="<ipython-input-73-db04383b747a>" source_line=3}
}

// Module entry point: takes one f32[5,5] matrix and three f32[5] vectors,
// evaluates %fused_computation in a single kLoop fusion, and returns the
// f32[5,5] result wrapped in a 1-tuple.
//
// The fusion operands are passed in reverse parameter order, so inside
// %fused_computation:
//   param_0.5  <- %parameter.4 (entry parameter 3)
//   param_1.10 <- %parameter.3 (entry parameter 2)
//   param_2.6  <- %parameter.2 (entry parameter 1)
//   param_3.5  <- %parameter.1 (entry parameter 0)
ENTRY %jit_loss__540.66 (parameter.1: f32[5,5], parameter.2: f32[5], parameter.3: f32[5], parameter.4: f32[5]) -> (f32[5,5]) {
  %parameter.4 = f32[5]{0} parameter(3)
  %parameter.3 = f32[5]{0} parameter(2)
  %parameter.2 = f32[5]{0} parameter(1), metadata={op_type="xla_call" op_name="jit(loss)/jit(jvp(loss))/xla_call[ backend=None\n                                   device=None\n                                   donated_invars=(False, False, False, False)\n                                   name=jvp(fn1) ]" source_file="<ipython-input-73-db04383b747a>" source_line=7}
  %parameter.1 = f32[5,5]{1,0} parameter(0)
  %fusion = f32[5,5]{1,0} fusion(f32[5]{0} %parameter.4, f32[5]{0} %parameter.3, f32[5]{0} %parameter.2, f32[5,5]{1,0} %parameter.1), kind=kLoop, calls=%fused_computation, metadata={op_type="dot_general" op_name="jit(loss)/jit(transpose(jvp(loss)))/jit(transpose(jvp(fn1)))/dot_general[ dimension_numbers=(((), ()), ((), ()))\n                                                                          precision=None ]" source_file="<ipython-input-73-db04383b747a>" source_line=3}
  // The single result is returned as a 1-element tuple, matching the declared
  // (f32[5,5]) return type.
  ROOT %tuple.65 = (f32[5,5]{1,0}) tuple(f32[5,5]{1,0} %fusion)
}

