    Add `secure_sum` intrinsic to TFF. (27807ecf)
    Authored by Michael Reneer.

    This change is part one of three:
    
    1. Add `secure_sum` intrinsic to TFF.
    2. Update transformations to support the `secure_sum` intrinsic.
    3. Update canonical form to support the `secure_sum` intrinsic.
    
    Here is an example using the API:
    
    ```python
    import tensorflow as tf

    from tensorflow_federated.python.core.api import computation_types
    from tensorflow_federated.python.core.api import computations
    from tensorflow_federated.python.core.api import intrinsics
    from tensorflow_federated.python.core.api import placements
    from tensorflow_federated.python.core.utils import computation_utils


    @computations.federated_computation
    def init_fn():
      """The `init` function for `computation_utils.IterativeProcess`."""
      return intrinsics.federated_value([0, 0], placements.SERVER)


    @computations.tf_computation(tf.int32, [tf.int32, tf.int32])
    def work(client_data, client_input):
      del client_data  # Unused
      del client_input  # Unused
      return [1, 1]


    @computations.federated_computation([
        computation_types.FederatedType([tf.int32, tf.int32], placements.SERVER),
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
    ])
    def next_fn(server_state, client_data):
      """The `next` function for `computation_utils.IterativeProcess`."""
      client_input = intrinsics.federated_broadcast(server_state)
      c3 = intrinsics.federated_zip([client_data, client_input])
      client_updates = intrinsics.federated_map(work, c3)
      # Aggregate the first element with the existing `federated_sum`...
      federated_update = intrinsics.federated_sum(client_updates[0])
      # ...and the second with the new `secure_sum`, using a bitwidth of 8.
      secure_update = intrinsics.secure_sum(client_updates[1], 8)
      new_server_state = intrinsics.federated_zip(
          [federated_update, secure_update])
      server_output = intrinsics.federated_value([], placements.SERVER)
      return new_server_state, server_output


    ip = computation_utils.IterativeProcess(init_fn, next_fn)
    ```
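
    The following plain-Python sketch models one round of `next_fn` so that its dataflow is concrete. It needs no TFF installation. It is a model under stated assumptions, not TFF's implementation:

    - Python lists stand in for values placed at `CLIENTS`.
    - Broadcast means handing the server state to every client.
    - `secure_sum` is replaced by a hypothetical `toy_secure_sum`. This function uses pairwise additive masking modulo `2**bitwidth`, one common way to realize secure summation. Treating `bitwidth` as the wraparound modulus is an assumption made for illustration. This commit does not specify overflow behavior.

    ```python
    import random
    from typing import List, Tuple

    # Plain-Python model of the commit's example; no TFF required.
    # ASSUMPTION: `toy_secure_sum` is a hypothetical stand-in for
    # `intrinsics.secure_sum`, using pairwise additive masking modulo
    # 2**bitwidth. It is not how TFF implements the intrinsic.


    def work(client_data: int, client_input: List[int]) -> List[int]:
      """Per-client work, mirroring the example's `work` computation."""
      del client_data, client_input  # Unused, as in the example.
      return [1, 1]


    def toy_secure_sum(values: List[int], bitwidth: int, seed: int = 0) -> int:
      """Sums `values` so the aggregator only ever sees masked inputs."""
      modulus = 2 ** bitwidth
      rng = random.Random(seed)
      n = len(values)
      # masks[a][b] (a < b) models a secret shared by clients a and b.
      masks = [[rng.randrange(modulus) for _ in range(n)] for _ in range(n)]
      masked = []
      for me, value in enumerate(values):
        y = value
        for other in range(n):
          if me < other:
            y += masks[me][other]  # Lower-indexed client adds the mask.
          elif other < me:
            y -= masks[other][me]  # Higher-indexed client subtracts it.
        masked.append(y % modulus)
      # Each mask is added once and subtracted once, so they cancel.
      return sum(masked) % modulus


    def init_fn() -> List[int]:
      """Initial server state, like `federated_value([0, 0], SERVER)`."""
      return [0, 0]


    def next_fn(server_state: List[int], client_data: List[int],
                bitwidth: int = 8) -> Tuple[List[int], list]:
      """One round: broadcast, map `work`, then a plain and a secure sum."""
      client_input = list(server_state)  # Broadcast: every client sees it.
      updates = [work(d, client_input) for d in client_data]
      federated_update = sum(u[0] for u in updates)  # Like `federated_sum`.
      secure_update = toy_secure_sum([u[1] for u in updates], bitwidth)
      new_server_state = [federated_update, secure_update]
      server_output: list = []
      return new_server_state, server_output


    state = init_fn()
    for _round in range(2):
      state, _output = next_fn(state, client_data=[10, 20, 30])
    print(state)  # [3, 3]
    ```

    With three clients each returning `[1, 1]`, both slots of the new server state come out to 3. Because the example's `next_fn` replaces the state rather than accumulating into it, every round yields `[3, 3]`. In this toy model, the masks only hide individual client values from the aggregator. They cancel pairwise, so the secure result equals the plain sum modulo `2**bitwidth`. For example, `toy_secure_sum([5, 7, 250], 8)` is `262 % 256 == 6`.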
    
    PiperOrigin-RevId: 290988239