diff --git a/drivers/mtd/spi/spi_flash.c b/drivers/mtd/spi/spi_flash.c
index 17f3d3cb147c1e1734a38a3fc26f95c0acc83c12..b82011d0fd2833301cc4b5df12ce0b1e4edd49e7 100644
--- a/drivers/mtd/spi/spi_flash.c
+++ b/drivers/mtd/spi/spi_flash.c
@@ -87,6 +87,9 @@ int spi_flash_cmd_write_multi(struct spi_flash *flash, u32 offset,
 	for (actual = 0; actual < len; actual += chunk_len) {
 		chunk_len = min(len - actual, page_size - byte_addr);
 
+		if (flash->spi->max_write_size)
+			chunk_len = min(chunk_len, flash->spi->max_write_size);
+
 		cmd[1] = page_addr >> 8;
 		cmd[2] = page_addr;
 		cmd[3] = byte_addr;
@@ -111,8 +114,11 @@ int spi_flash_cmd_write_multi(struct spi_flash *flash, u32 offset,
 		if (ret)
 			break;
 
-		page_addr++;
-		byte_addr = 0;
+		byte_addr += chunk_len;
+		if (byte_addr == page_size) {
+			page_addr++;
+			byte_addr = 0;
+		}
 	}
 
 	debug("SF: program %s %zu bytes @ %#x\n",