diff --git a/README.md b/README.md
index 489c997..f9fdd1d 100644
--- a/README.md
+++ b/README.md
@@ -104,6 +104,7 @@ $ python assert.py --use-cuda --causal --striped-ring-attn
 - [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
 - [ ] figure out how to pytest distributed pytorch
 - [ ] use sdp context manager to validate when it is possible to use `ring_flash_attn_cuda`, otherwise assert out
+- [ ] improvise a variant where each machine keeps compressed summary tokens, and one only ring pass those summary token for some given distance
 
 ## Citations