The error message actually tells you the fix, but the suggestion seemed unrelated enough to what I was doing that I was hesitant to just try the config:
torch._dynamo.config.capture_dynamic_output_shape_ops = True
The general issue is that capturing ops whose output shape depends on the data isn’t turned on by default due to various issues, but for your case it may actually work. It is also interesting to see where TorchVision hit this and worked around it with torch.where instead.