from net_3tensor access *;

// Pen used for every wire/mode index label below: magenta text at 7pt.
pen mode_label_pen = magenta + fontsize(7pt);

// STEP 1 ==============================================================================
// Build a three_tensor network, name its tensors, draw it into P1, then annotate
// its wires with index labels drawn in mode_label_pen.

int k = 5;
int wire_gap = 12;
int wire_height = 32;

picture P1;
network N1 = three_tensor(k=k, wire_gap=wire_gap, gap=25, wire_height=wire_height,
                          attach_input=true);

// Tensors 0 .. 3k-1 cycle through alpha, beta, gamma (chosen by index mod 3);
// the two tensors that follow them are labelled A and B.
for (int t = 0; t < 3*k; ++t) {
  string[] cycle = {"$\alpha$", "$\beta$", "$\gamma$"};
  N1.tensors[t].label = cycle[t % 3];
}
N1.tensors[3*k].label = "$A$";
N1.tensors[3*k+1].label = "$B$";

N1.draw(P1);

// For each i, place six index labels. Negative y puts a label below a tensor and
// positive y puts it above. Each label sits near the first or the last tensor of
// the alpha/beta/gamma families. One more label goes beside mode join i.
// The pixel offsets appear hand-tuned for this particular layout.
for (int i = 0; i < k; ++i) {
  string n = string(i+1);
  int last = 3*(k-1);                     // index of the final alpha tensor
  int from_first = wire_gap*i;            // steps right from the first tensor of a family
  int from_last = -wire_gap*(k-i-2) - 3;  // steps right relative to the last tensor of a family

  label(P1, "$i'_{" + n + "}$", N1.tensors[0].mid + (from_first-8, -37), mode_label_pen);
  label(P1, "$k_{" + n + "}$", N1.tensors[last].mid + (from_last, -37), mode_label_pen);

  label(P1, "$k'_{" + n + "}$", N1.tensors[1].mid + (from_first-8, -37), mode_label_pen);
  label(P1, "$j'_{" + n + "}$", N1.tensors[last+1].mid + (from_last, -37), mode_label_pen);

  label(P1, "$i_{" + n + "}$", N1.tensors[2].mid + (from_first-7.5, 37), mode_label_pen);
  label(P1, "$j_{" + n + "}$", N1.tensors[last+2].mid + (from_last, 37), mode_label_pen);

  label(P1, "$\ell_{" + n + "}$", N1.mode_joins[i].mid + (5,4), mode_label_pen);
}

// STEP 2 ==============================================================================
// Rebuild the network with the same parameters as step 1 and draw it into P2
// with draw_execution enabled (no extra labels). Then place both pictures on
// the current picture.

picture P2;
network N2 = three_tensor(k=k, wire_gap=wire_gap, gap=25, wire_height=wire_height,
                          attach_input=true);
N2.draw(P2, draw_execution=true);

// Compose the figure: P1 about the origin, P2 about (300, 0) to its right.
add(P1, (0,0));
add(P2, (300, 0));
