Training infrastructure improvements for production use #1

Open
opened 2026-03-23 13:10:58 +01:00 by catboi · 2 comments

Suggested Improvements

Based on training runs, here are infrastructure optimizations that would significantly improve training speed and reliability:

1. GPU Acceleration Verification

  • DJL/PyTorch dependencies are included but GPU usage is untested
  • 25+ hours on CPU vs minutes on GPU
  • Need to verify CUDA tensors are actually being used

2. Quantized Activations

  • Currently uses float32 for activations even with ternary weights
  • BFloat16 or Int8 for activations would speed up CPU inference significantly
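As a concrete illustration of the int8 option, here is a minimal symmetric per-tensor quantize/dequantize pair in plain Kotlin. The function names are hypothetical (nothing like this exists in the repo yet); a real integration would quantize activations between layers and keep the scale alongside the buffer:

```kotlin
import kotlin.math.abs
import kotlin.math.roundToInt

// Symmetric per-tensor int8 quantization (illustrative sketch, not repo API).
// Returns the quantized bytes plus the scale needed to dequantize.
fun quantizeInt8(x: FloatArray): Pair<ByteArray, Float> {
    val maxAbs = x.maxOf { abs(it) }.coerceAtLeast(1e-8f)
    val scale = maxAbs / 127f
    val q = ByteArray(x.size) { i ->
        (x[i] / scale).roundToInt().coerceIn(-127, 127).toByte()
    }
    return q to scale
}

fun dequantizeInt8(q: ByteArray, scale: Float): FloatArray =
    FloatArray(q.size) { i -> q[i] * scale }
```

The round trip loses at most half a quantization step (about `maxAbs / 254`), which is typically acceptable for activations.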

3. Gradient Checkpointing

  • Recompute activations instead of storing them
  • Allows larger batch sizes and longer sequences

4. Learning Rate Scheduling

  • Current lr=0.0005 is static
  • Warmup + cosine decay would improve convergence
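A warmup + cosine schedule is a few lines of arithmetic. This is a hedged sketch (the function and its parameters are not in the repo) using the current `lr=0.0005` as the peak:

```kotlin
import kotlin.math.PI
import kotlin.math.cos

// Linear warmup to baseLr, then cosine decay to zero (hypothetical helper).
fun lrAt(step: Int, totalSteps: Int, warmupSteps: Int, baseLr: Double = 5e-4): Double {
    require(step in 0..totalSteps)
    return if (step < warmupSteps) {
        // Ramp linearly from baseLr/warmupSteps up to baseLr.
        baseLr * (step + 1) / warmupSteps
    } else {
        // Cosine from baseLr down to 0 over the remaining steps.
        val progress = (step - warmupSteps).toDouble() / (totalSteps - warmupSteps)
        0.5 * baseLr * (1.0 + cos(PI * progress))
    }
}
```

Calling `lrAt(step, totalSteps, warmupSteps)` once per optimizer step and passing the result to the optimizer is enough; no state needs to be stored.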

5. Mixed Precision Training

  • Float16 for forward pass, float32 for gradient accumulation
  • Faster on CPUs with AVX2/AVX512 vector units

6. Early Stopping

  • No validation loss monitoring during training
  • Wasted epochs when model has converged
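A patience-based monitor is the usual fix. The class below is an illustrative sketch (name and defaults are assumptions, not repo code); the training loop would call `shouldStop` once per epoch with the validation loss:

```kotlin
// Stop training when validation loss has not improved for `patience` epochs.
class EarlyStopping(private val patience: Int = 3, private val minDelta: Double = 1e-4) {
    private var best = Double.MAX_VALUE
    private var badEpochs = 0

    /** Returns true when training should stop. */
    fun shouldStop(valLoss: Double): Boolean {
        if (valLoss < best - minDelta) {
            best = valLoss       // genuine improvement: reset the counter
            badEpochs = 0
        } else {
            badEpochs++          // no improvement this epoch
        }
        return badEpochs >= patience
    }
}
```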

7. Byte-Level BPE Tokenizer

  • Shorter sequences than character-level, with more meaningful subword units
  • A middle ground between character-level and word-level tokenization
  • Standard in modern GPT models

8. Efficient Attention

  • Flash Attention implementation
  • Or sliding window attention for long contexts

9. Streaming Data Pipeline

  • Currently loads all text into memory
  • Streaming loader + prefetch would handle large corpora better
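One way to sketch the streaming + prefetch idea with only the JDK: a daemon thread reads lines into a bounded queue while the consumer iterates. The class name and design are assumptions for illustration, not the repo's loader:

```kotlin
import java.io.BufferedReader
import java.util.concurrent.ArrayBlockingQueue
import kotlin.concurrent.thread

// Background prefetch over a line-oriented corpus. The bounded queue applies
// backpressure, so memory stays O(capacity) regardless of corpus size.
class PrefetchingLineStream(reader: BufferedReader, capacity: Int = 256) : Iterator<String> {
    private val queue = ArrayBlockingQueue<Any>(capacity)
    private val sentinel = Any()   // marks end of input

    init {
        thread(isDaemon = true) {
            reader.useLines { lines -> lines.forEach { queue.put(it) } }
            queue.put(sentinel)
        }
    }

    private var nextItem: Any? = null

    override fun hasNext(): Boolean {
        if (nextItem == null) nextItem = queue.take()
        return nextItem !== sentinel
    }

    override fun next(): String {
        if (!hasNext()) throw NoSuchElementException()
        val item = nextItem as String
        nextItem = null
        return item
    }
}
```

A real pipeline would prefetch tokenized batches rather than raw lines, but the queue-and-sentinel pattern is the same.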

10. Checkpoint Saving

  • Only saves at end of training
  • Checkpoint per epoch to avoid losing progress on crashes
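A minimal per-epoch saver with rotation might look like the following. The file layout (`epoch-NNNN.ckpt`) and the serialized-weights-as-bytes interface are assumptions for the sketch, not the project's format:

```kotlin
import java.io.File

// Save a checkpoint after every epoch, keeping only the newest `keepLast`.
fun saveCheckpoint(dir: File, epoch: Int, weights: ByteArray, keepLast: Int = 3) {
    dir.mkdirs()
    File(dir, "epoch-%04d.ckpt".format(epoch)).writeBytes(weights)
    // Zero-padded names sort lexicographically by epoch; delete all but the newest.
    dir.listFiles { f -> f.name.endsWith(".ckpt") }!!
        .sortedBy { it.name }
        .dropLast(keepLast)
        .forEach { it.delete() }
}
```

Writing to a temp file and renaming would additionally protect against a crash mid-write, since renames are atomic on most filesystems.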

The core BitNet architecture is solid. These are infrastructure/optimization improvements.

Priority: Low (architecture works, just needs optimization)

LeNooby09 added reference master 2026-03-23 13:15:12 +01:00
Owner
should be addressed in [7c03f22e435478277215ec58fcdfedb61e193f5b](https://git.lenooby09.tech/LeNooby09/neural-bit/commit/7c03f22e435478277215ec58fcdfedb61e193f5b)
Author

Additional Feedback on Latest Commits

Nice work adding checkpointing, early stopping, LR scheduling, and streaming dataset. Two items need attention:

BPETokenizer.encode() - O(n*m) complexity issue

Current implementation:

```kotlin
for ((left, right) in merges) {
    tokens = mergePair(tokens, left, right, ...)
}
```

Each `mergePair` call iterates through the entire token list, so with n tokens and m learned merges, encoding costs O(n·m).

Fix: Precompute a merge table (HashMap) at construction time. For each token pair, store which merge applies. Then encode in a single pass.

Alternative: Cache the merge result for common byte sequences.
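One concrete realization of the merge-table idea, sketched in the GPT-2 style: precompute a pair → rank map once, then repeatedly merge whichever adjacent pair has the lowest rank. The class and its interface are hypothetical, not the repo's `BPETokenizer`:

```kotlin
// Rank-based BPE encode: the pair -> rank map is built once at construction,
// so per-string cost no longer scales with the size of the merge vocabulary.
class RankedBpe(merges: List<Pair<String, String>>) {
    private val ranks: Map<Pair<String, String>, Int> =
        merges.withIndex().associate { (i, p) -> p to i }

    fun encode(symbols: List<String>): List<String> {
        var tokens = symbols
        while (tokens.size > 1) {
            // Among adjacent pairs currently present, pick the earliest-learned merge.
            val best = tokens.zipWithNext()
                .filter { it in ranks }
                .minByOrNull { ranks.getValue(it) } ?: break
            val merged = mutableListOf<String>()
            var i = 0
            while (i < tokens.size) {
                if (i < tokens.size - 1 && (tokens[i] to tokens[i + 1]) == best) {
                    merged.add(best.first + best.second)
                    i += 2
                } else {
                    merged.add(tokens[i])
                    i++
                }
            }
            tokens = merged
        }
        return tokens
    }
}
```

Each loop iteration shrinks the token list by at least one, so the work is bounded by the string length rather than by the number of merges in the vocabulary.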

StreamingTextDataset.get(index) - ignores index parameter

```kotlin
override fun get(index: Int): Pair<FloatTensor, FloatTensor> {
    // index is ignored - returns next from buffer
    return buffer.poll(...) 
}
```

This breaks the Dataset contract where get(0) and get(1) should return distinct samples. Currently calling get(0) twice returns different results.

Fix options:

  1. If true streaming is needed, implement a ring buffer with proper index tracking
  2. If random access is needed, use TextDataset instead of StreamingTextDataset for that use case
  3. Document that get() ignores index and behaves like an iterator
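For option 2, a deterministic, index-addressable dataset over a pre-tokenized array could look like this. The class name and `(input, target)` window shape are assumptions for the sketch, not the repo's `TextDataset`:

```kotlin
// Deterministic windows over a token array: get(i) always returns the same
// sample for the same i, satisfying the Dataset random-access contract.
class WindowedTextDataset(private val tokens: IntArray, private val seqLen: Int) {
    val size: Int get() = (tokens.size - 1) / seqLen

    /** Input is tokens[start..start+seqLen); target is the same window shifted by one. */
    fun get(index: Int): Pair<IntArray, IntArray> {
        require(index in 0 until size)
        val start = index * seqLen
        val input = tokens.copyOfRange(start, start + seqLen)
        val target = tokens.copyOfRange(start + 1, start + seqLen + 1)
        return input to target
    }
}
```

Shuffling then becomes a permutation of indices rather than a property of the loader, which also makes runs reproducible from a seed.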