sum-xor pairs


Write a function that accepts two non-negative integers s and x and returns all integer pairs (a, b) such that a + b = s and a xor b = x.

 $ cat sxp.rkt  
 #lang racket 
  
 (require racket/set) 
  
 (define (smallest-at-least-po2 a b)

   ; Return 2^n, where n is the larger of the bit sizes of a and b
   ; (both exact integers).  The result is therefore a power of 2
   ; strictly greater than both values, and never less than 2.

   (define (bit-size i)

     ; Return the smallest number of bits required to represent i;
     ; anything below 2 counts as one bit.  integer-length is exact,
     ; so the answer stays right for values too large for a
     ; floating-point logarithm to resolve.

     (if (< i 2)
       1
       (integer-length i)))

   (expt 2 (max (bit-size a) (bit-size b))))
  
  
 ; These functions treat numbers as n-bit values, where n is the
 ; smallest number of bits capable of representing both the sum and
 ; the xor value.  Sums are taken modulo 2^n (that is, any carry out
 ; of the top bit is ignored), so a pair may contain values larger
 ; than the sum value and still count as adding up to it.  Also,
 ; because sum and xor are commutative, each number pair is created
 ; with the smaller value first (on the left, the car).
  
  
 (define (sum-xor-pairs-n ab-sum ab-xor)

   ; Return a set of number pairs such that, for each pair, the
   ; sum of the two numbers equals ab-sum (ignoring overflow) and
   ; the xor of the two numbers equals ab-xor.  Both arguments are
   ; expected to be non-negative exact integers.

   ; The search walks the sum and xor values one bit at a time,
   ; least-significant bit first, and each bit position can fork
   ; into two recursive calls.  The work is therefore bounded by
   ; 2^n for n bits, which is roughly linear in the larger of the
   ; sum and xor values.

   ; This function essentially implements the truth tables for sum
   ; and xor:
   ;
   ;   Ci  a  b  Co  s  x
   ;    0  0  0   0  0  0
   ;    0  0  1   0  1  1
   ;    0  1  0   0  1  1
   ;    0  1  1   1  0  0
   ;    1  0  0   0  1  0
   ;    1  0  1   1  0  1
   ;    1  1  0   1  0  1
   ;    1  1  1   1  1  0
   ;
   ;   Ci the carry-in bit
   ;   a  a bit from one of the pair values
   ;   b  the corresponding bit from the other pair value
   ;   Co the carry-out from Ci + a + b
   ;   s  Ci + a + b
   ;   x  a xor b
   ;
   ; Interesting things to note about this table:
   ;
   ;   The s and x bits are the same when Ci = 0.
   ;   The s and x bits are different when Ci = 1.
   ;   xor ignores Ci (and Co).


   (define (add-msb a-msb b-msb ab-pairs)

     ; Prepend the given most-significant bits to every bit-list
     ; pair in the given set; return the new set.  Each element is
     ; a two-slot vector holding a's bit list and b's bit list.

     (for/set ([v ab-pairs])
       (vector
         (cons a-msb (vector-ref v 0))
         (cons b-msb (vector-ref v 1)))))


   (define (vector->pair ab-pairs)

     ; Return a set of number pairs, where each number corresponds
     ; to a bit list in the given pairs and each pair corresponds
     ; to a vector in the given set.  The smaller value appears
     ; first in the pair.

     (define (implode bit-list)

       ; Return the number equivalent to the given bit list (msb on
       ; the left).

       (let loop ((n 0) (bit-list bit-list))
         (if (null? bit-list)
           n
           (loop (+ (* n 2) (car bit-list)) (cdr bit-list)))))

     (for/set ((v ab-pairs))
       (let ((a (implode (vector-ref v 0)))
             (b (implode (vector-ref v 1))))
         (cons (min a b) (max a b)))))


   (define (oops)

     ; Signal that a bit combination outside the truth table was
     ; seen.  Takes no arguments, matching every call site below.

     (raise-arguments-error
       'sum-xor-pairs-n "some unfathomable error"))


   (vector->pair
     (let loop

       ((ab-sum ab-sum)
        (ab-xor ab-xor)
        (carry-in 0)

        ; When both inputs are zero the values are one bit wide, and
        ; the loop below would stop before emitting any bits, so
        ; the two one-bit answers are seeded directly.

        (ab-pairs (if (and (= ab-sum 0) (= ab-xor 0))
                    (set #((0) (0)) #((1) (1)))
                    (set #(() ())))))

       (if (and (= ab-sum 0) (= ab-xor 0))

         ; All bits consumed; any pending carry is overflow and is
         ; ignored.

         ab-pairs

         (let

           ((sum-bit (remainder ab-sum 2))
            (xor-bit (remainder ab-xor 2))
            (ab-sum (quotient ab-sum 2))
            (ab-xor (quotient ab-xor 2)))

           (cond
             ((= carry-in 0)
               (cond
                 ((and (= sum-bit 0) (= xor-bit 0))
                   (set-union
                     (loop ab-sum ab-xor 0 (add-msb 0 0 ab-pairs))
                     (loop ab-sum ab-xor 1 (add-msb 1 1 ab-pairs))))

                 ((and (= sum-bit 1) (= xor-bit 1))
                   (set-union
                     (loop ab-sum ab-xor 0 (add-msb 0 1 ab-pairs))
                     (loop ab-sum ab-xor 0 (add-msb 1 0 ab-pairs))))

                 ((not (= sum-bit xor-bit))

                   ; If the carry-in's zero, the sum of the two
                   ; bits (ignoring carry-out) must equal the xor
                   ; of the two bits.  Because that's not the
                   ; case, there can be no solutions down this
                   ; branch.

                   (set))

                 (#t
                   (oops))))

             ((= carry-in 1)
               (cond
                 ((and (= sum-bit 1) (= xor-bit 0))
                   (set-union
                     (loop ab-sum ab-xor 0 (add-msb 0 0 ab-pairs))
                     (loop ab-sum ab-xor 1 (add-msb 1 1 ab-pairs))))

                 ((and (= sum-bit 0) (= xor-bit 1))
                   (set-union
                     (loop ab-sum ab-xor 1 (add-msb 0 1 ab-pairs))
                     (loop ab-sum ab-xor 1 (add-msb 1 0 ab-pairs))))

                 ((= sum-bit xor-bit)

                   ; If the carry-in's one, the sum of the two
                   ; bits (ignoring carry-out) must differ from
                   ; the xor of the two bits.  Because the bits
                   ; are equal here, there can be no solutions
                   ; down this branch.

                   (set))

                 (#t
                   (oops))))

             (#t
               (oops))))))))
  
  
 (define (sum-xor-pairs-nsq ab-sum ab-xor)

   ; Brute-force reference: return the set of pairs (a . b) with
   ; a <= b, both below N, whose sum modulo N equals ab-sum and
   ; whose xor equals ab-xor.  N is the power of 2 produced by
   ; smallest-at-least-po2 for the two arguments.

   ; Every candidate pair below N is examined, so the work is
   ; quadratic in N (a.k.a. n-squared).

   (define N (smallest-at-least-po2 ab-sum ab-xor))

   (for*/set ([a (in-range N)]
              [b (in-range a N)]
              #:when (and (= (remainder (+ a b) N) ab-sum)
                          (= (bitwise-xor a b) ab-xor)))
     (cons a b)))
  
 ; Demo: evaluate the linear solver for sum 9 and xor 5; as a
 ; module-level expression its result is printed when the file runs.
 (sum-xor-pairs-n 9 5) 
  
  
 (require rackunit) 
  
 (define (check-sum-xor-pairs n)

   ; For every sum and xor value from 0 to n, check that each pair
   ; returned by sum-xor-pairs-n really has that sum (modulo the
   ; power of 2 from smallest-at-least-po2) and that xor, and that
   ; the result equals the brute-force answer from
   ; sum-xor-pairs-nsq.  Failures are reported through rackunit;
   ; the function itself returns #t.

   (define (check-sum-xor-list ab-pairs ab-sum ab-xor)

     ; Check every pair in the set against the sum and xor values.

     (define N (smallest-at-least-po2 ab-sum ab-xor))

     ; check-equal? rather than check-eq?: numbers should be
     ; compared by value, not by object identity.

     (for ([p (in-set ab-pairs)])
       (check-equal? (remainder (+ (car p) (cdr p)) N) ab-sum)
       (check-equal? (bitwise-xor (car p) (cdr p)) ab-xor)))

   (do ((ab-sum 0 (+ 1 ab-sum))) ((> ab-sum n) #t)
     (do ((ab-xor 0 (+ 1 ab-xor))) ((> ab-xor n) #t)
       (let ((ab-pairs-n (sum-xor-pairs-n ab-sum ab-xor)))
         (check-sum-xor-list ab-pairs-n ab-sum ab-xor)

         ; Assert the comparison with the reference implementation
         ; so a mismatch is actually reported.

         (check-equal? ab-pairs-n (sum-xor-pairs-nsq ab-sum ab-xor))))))
  
 ; Run the checks for every sum and xor value from 0 to 100.
 (check-sum-xor-pairs 100) 
  
  
 (define (time-it f n iters)

   ; Return the mean CPU time (as reported by time-apply, rounded
   ; to an exact integer) of iters sweeps, where one sweep calls f
   ; on every (sum, xor) combination with both values running from
   ; 0 through n.

   (define (sweep)
     (for ([s (in-naturals)] #:break (> s n))
       (for ([x (in-naturals)] #:break (> x n))
         (f s x))))

   (define (accumulate total done)
     (cond
       [(= done iters)
        (inexact->exact (round (/ total iters)))]
       [else
        ; time-apply yields the result list plus cpu, real and gc
        ; times; only the cpu time is accumulated.
        (let-values ([(results cpu-ms real-ms gc-ms)
                      (time-apply sweep '())])
          (accumulate (+ total cpu-ms) (+ done 1)))]))

   (accumulate 0 0))
  
 ; Print, for limits 10, 20, ..., 100, the average time each solver
 ; takes over three sweeps of all sum/xor values up to the limit.
 (let ([iters 3])
   (for ([limit (in-range 10 101 10)])
     (printf
      "sum-xor max: ~a, sum-xor-pairs-n: ~a, sum-xor-pairs-nsq: ~a\n"
      limit
      (time-it sum-xor-pairs-n limit iters)
      (time-it sum-xor-pairs-nsq limit iters)))
   #t)
  
 $ mzscheme sxp.rkt  
 (set '(2 . 7) '(3 . 6) '(10 . 15) '(11 . 14)) 
 #t 
 sum-xor max: 10, sum-xor-pairs-n: 3, sum-xor-pairs-nsq: 1 
 sum-xor max: 20, sum-xor-pairs-n: 7, sum-xor-pairs-nsq: 3 
 sum-xor max: 30, sum-xor-pairs-n: 17, sum-xor-pairs-nsq: 9 
 sum-xor max: 40, sum-xor-pairs-n: 32, sum-xor-pairs-nsq: 39 
 sum-xor max: 50, sum-xor-pairs-n: 52, sum-xor-pairs-nsq: 80 
 sum-xor max: 60, sum-xor-pairs-n: 73, sum-xor-pairs-nsq: 128 
 sum-xor max: 70, sum-xor-pairs-n: 107, sum-xor-pairs-nsq: 296 
 sum-xor max: 80, sum-xor-pairs-n: 145, sum-xor-pairs-nsq: 547 
 sum-xor max: 90, sum-xor-pairs-n: 188, sum-xor-pairs-nsq: 827 
 sum-xor max: 100, sum-xor-pairs-n: 236, sum-xor-pairs-nsq: 1128 
 #t 
  
 $