Merge "Fix a case when brotli writer fails to write last few blocks of data" into oc-mr1-dev
diff --git a/tests/component/updater_test.cpp b/tests/component/updater_test.cpp
index 01b86f2..6c341c1 100644
--- a/tests/component/updater_test.cpp
+++ b/tests/component/updater_test.cpp
@@ -592,10 +592,10 @@
   ASSERT_EQ(0, zip_writer.StartEntry("new.dat.br", 0));
 
   auto generator = []() { return rand() % 128; };
-  // Generate 2048 blocks of random data.
+  // Generate 100 blocks of random data.
   std::string brotli_new_data;
-  brotli_new_data.reserve(4096 * 2048);
-  generate_n(back_inserter(brotli_new_data), 4096 * 2048, generator);
+  brotli_new_data.reserve(4096 * 100);
+  generate_n(back_inserter(brotli_new_data), 4096 * 100, generator);
 
   size_t encoded_size = BrotliEncoderMaxCompressedSize(brotli_new_data.size());
   std::vector<uint8_t> encoded_data(encoded_size);
@@ -609,8 +609,19 @@
   ASSERT_EQ(0, zip_writer.StartEntry("patch_data", 0));
   ASSERT_EQ(0, zip_writer.FinishEntry());
 
+  // Write a few small chunks of new data, then a large chunk, and finally a few small chunks.
+  // This helps us to catch potential short writes.
   std::vector<std::string> transfer_list = {
-    "4", "2048", "0", "0", "new 4,0,512,512,1024", "new 2,1024,2048",
+    "4",
+    "100",
+    "0",
+    "0",
+    "new 2,0,1",
+    "new 2,1,2",
+    "new 4,2,50,50,97",
+    "new 2,97,98",
+    "new 2,98,99",
+    "new 2,99,100",
   };
   ASSERT_EQ(0, zip_writer.StartEntry("transfer_list", 0));
   std::string commands = android::base::Join(transfer_list, '\n');
diff --git a/updater/blockimg.cpp b/updater/blockimg.cpp
index 2bec487..a0b9ad2 100644
--- a/updater/blockimg.cpp
+++ b/updater/blockimg.cpp
@@ -158,20 +158,22 @@
     CHECK_NE(tgt.size(), static_cast<size_t>(0));
   };
 
-  virtual ~RangeSinkWriter() {};
-
   bool Finished() const {
     return next_range_ == tgt_.size() && current_range_left_ == 0;
   }
 
-  // Return number of bytes consumed; and 0 indicates a writing failure.
-  virtual size_t Write(const uint8_t* data, size_t size) {
+  size_t AvailableSpace() const {
+    return BLOCKSIZE * tgt_.blocks() - BytesWritten();  // Bytes of |tgt_| not yet written.
+  }
+
+  // Returns the number of bytes written; 0 indicates a write failure.
+  size_t Write(const uint8_t* data, size_t size) {
     if (Finished()) {
       LOG(ERROR) << "range sink write overrun; can't write " << size << " bytes";
       return 0;
     }
 
-    size_t consumed = 0;
+    size_t written = 0;
     while (size > 0) {
       // Move to the next range as needed.
       if (!SeekToOutputRange()) {
@@ -191,18 +193,18 @@
       size -= write_now;
 
       current_range_left_ -= write_now;
-      consumed += write_now;
+      written += write_now;
     }
 
-    bytes_written_ += consumed;
-    return consumed;
+    bytes_written_ += written;
+    return written;
   }
 
   size_t BytesWritten() const {
     return bytes_written_;
   }
 
- protected:
+ private:
   // Set up the output cursor, move to next range if needed.
   bool SeekToOutputRange() {
     // We haven't finished the current range yet.
@@ -241,75 +243,6 @@
   size_t bytes_written_;
 };
 
-class BrotliNewDataWriter : public RangeSinkWriter {
- public:
-  BrotliNewDataWriter(int fd, const RangeSet& tgt, BrotliDecoderState* state)
-      : RangeSinkWriter(fd, tgt), state_(state) {}
-
-  size_t Write(const uint8_t* data, size_t size) override {
-    if (Finished()) {
-      LOG(ERROR) << "Brotli new data write overrun; can't write " << size << " bytes";
-      return 0;
-    }
-    CHECK(state_ != nullptr);
-
-    size_t consumed = 0;
-    while (true) {
-      // Move to the next range as needed.
-      if (!SeekToOutputRange()) {
-        break;
-      }
-
-      size_t available_in = size;
-      size_t write_now = std::min<size_t>(32768, current_range_left_);
-      uint8_t buffer[write_now];
-
-      size_t available_out = write_now;
-      uint8_t* next_out = buffer;
-
-      // The brotli decoder will update |data|, |available_in|, |next_out| and |available_out|.
-      BrotliDecoderResult result = BrotliDecoderDecompressStream(
-          state_, &available_in, &data, &available_out, &next_out, nullptr);
-
-      // We don't have a way to recover from the decode error; report the failure.
-      if (result == BROTLI_DECODER_RESULT_ERROR) {
-        LOG(ERROR) << "Decompression failed with "
-                   << BrotliDecoderErrorString(BrotliDecoderGetErrorCode(state_));
-        return 0;
-      }
-
-      if (write_all(fd_, buffer, write_now - available_out) == -1) {
-        return 0;
-      }
-
-      LOG(DEBUG) << "bytes written: " << write_now - available_out << ", bytes consumed "
-                 << size - available_in << ", decoder status " << result;
-
-      // Update the total bytes written to output by the current writer; this is different from the
-      // consumed input bytes.
-      bytes_written_ += write_now - available_out;
-      current_range_left_ -= (write_now - available_out);
-      consumed += (size - available_in);
-
-      // Update the remaining size. The input data ptr is already updated by brotli decoder
-      // function.
-      size = available_in;
-
-      // Continue if we have more output to write, or more input to consume.
-      if (result == BROTLI_DECODER_RESULT_SUCCESS ||
-          (result == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && size == 0)) {
-        break;
-      }
-    }
-
-    return consumed;
-  }
-
- private:
-  // Pointer to the decoder state. (initialized by PerformBlockImageUpdate)
-  BrotliDecoderState* state_;
-};
-
 /**
  * All of the data for all the 'new' transfers is contained in one file in the update package,
  * concatenated together in the order in which transfers.list will need it. We want to stream it out
@@ -354,16 +287,73 @@
 
     // At this point nti->writer is set, and we own it. The main thread is waiting for it to
     // disappear from nti.
-    size_t consumed = nti->writer->Write(data, size);
-
-    // We encounter a fatal error if we fail to consume any input bytes. If this happens, abort the
-    // extraction.
-    if (consumed == 0) {
-      LOG(ERROR) << "Failed to process " << size << " input bytes.";
+    size_t write_now = std::min(size, nti->writer->AvailableSpace());
+    if (nti->writer->Write(data, write_now) != write_now) {
+      LOG(ERROR) << "Failed to write " << write_now << " bytes.";
       return false;
     }
-    data += consumed;
-    size -= consumed;
+
+    data += write_now;
+    size -= write_now;
+
+    if (nti->writer->Finished()) {
+      // We have written all the bytes desired by this writer.
+
+      pthread_mutex_lock(&nti->mu);
+      nti->writer = nullptr;
+      pthread_cond_broadcast(&nti->cv);
+      pthread_mutex_unlock(&nti->mu);
+    }
+  }
+
+  return true;
+}
+
+static bool receive_brotli_new_data(const uint8_t* data, size_t size, void* cookie) {
+  NewThreadInfo* nti = static_cast<NewThreadInfo*>(cookie);
+
+  while (size > 0 || BrotliDecoderHasMoreOutput(nti->brotli_decoder_state)) {
+    // Wait for nti->writer to be non-null, indicating some of this data is wanted.
+    pthread_mutex_lock(&nti->mu);
+    while (nti->writer == nullptr) {
+      pthread_cond_wait(&nti->cv, &nti->mu);
+    }
+    pthread_mutex_unlock(&nti->mu);
+
+    // At this point nti->writer is set, and we own it. The main thread is waiting for it to
+    // disappear from nti.
+
+    size_t buffer_size = std::min<size_t>(32768, nti->writer->AvailableSpace());
+    if (buffer_size == 0) {
+      LOG(ERROR) << "No space left in output range";
+      return false;
+    }
+    uint8_t buffer[buffer_size];
+    size_t available_in = size;
+    size_t available_out = buffer_size;
+    uint8_t* next_out = buffer;
+
+    // The brotli decoder will update |data|, |available_in|, |next_out| and |available_out|.
+    BrotliDecoderResult result = BrotliDecoderDecompressStream(
+        nti->brotli_decoder_state, &available_in, &data, &available_out, &next_out, nullptr);
+
+    if (result == BROTLI_DECODER_RESULT_ERROR) {
+      LOG(ERROR) << "Decompression failed with "
+                 << BrotliDecoderErrorString(BrotliDecoderGetErrorCode(nti->brotli_decoder_state));
+      return false;
+    }
+
+    LOG(DEBUG) << "bytes to write: " << buffer_size - available_out << ", bytes consumed "
+               << size - available_in << ", decoder status " << result;
+
+    size_t write_now = buffer_size - available_out;
+    if (nti->writer->Write(buffer, write_now) != write_now) {
+      LOG(ERROR) << "Failed to write " << write_now << " bytes.";
+      return false;
+    }
+
+    // Update the remaining size. The input data ptr is already updated by brotli decoder function.
+    size = available_in;
 
     if (nti->writer->Finished()) {
       // We have written all the bytes desired by this writer.
@@ -380,8 +370,11 @@
 
 static void* unzip_new_data(void* cookie) {
   NewThreadInfo* nti = static_cast<NewThreadInfo*>(cookie);
-  ProcessZipEntryContents(nti->za, &nti->entry, receive_new_data, nti);
-
+  if (nti->brotli_compressed) {
+    ProcessZipEntryContents(nti->za, &nti->entry, receive_brotli_new_data, nti);
+  } else {
+    ProcessZipEntryContents(nti->za, &nti->entry, receive_new_data, nti);
+  }
   pthread_mutex_lock(&nti->mu);
   nti->receiver_available = false;
   if (nti->writer != nullptr) {
@@ -1240,12 +1233,7 @@
     LOG(INFO) << " writing " << tgt.blocks() << " blocks of new data";
 
     pthread_mutex_lock(&params.nti.mu);
-    if (params.nti.brotli_compressed) {
-      params.nti.writer =
-          std::make_unique<BrotliNewDataWriter>(params.fd, tgt, params.nti.brotli_decoder_state);
-    } else {
-      params.nti.writer = std::make_unique<RangeSinkWriter>(params.fd, tgt);
-    }
+    params.nti.writer = std::make_unique<RangeSinkWriter>(params.fd, tgt);
     pthread_cond_broadcast(&params.nti.cv);
 
     while (params.nti.writer != nullptr) {
@@ -1485,7 +1473,6 @@
   if (params.canwrite) {
     params.nti.za = za;
     params.nti.entry = new_entry;
-    // The entry is compressed by brotli if has a 'br' extension.
     params.nti.brotli_compressed = android::base::EndsWith(new_data_fn->data, ".br");
     if (params.nti.brotli_compressed) {
       // Initialize brotli decoder state.