diff --git a/drivers/mtd/spi/sf_internal.h b/drivers/mtd/spi/sf_internal.h
index 16dd45bfc646b81bee04b351a29060eb9e36a19a..ed5c391dc2c6c34fc61fe605dae5f43ecb636663 100644
--- a/drivers/mtd/spi/sf_internal.h
+++ b/drivers/mtd/spi/sf_internal.h
@@ -227,7 +227,6 @@ void spi_flash_mtd_unregister(void);
 
 /**
  * spi_flash_scan - scan the SPI FLASH
- * @spi:	the spi slave structure
  * @flash:	the spi flash structure
  *
  * The drivers can use this fuction to scan the SPI FLASH.
@@ -236,6 +235,6 @@ void spi_flash_mtd_unregister(void);
  *
  * Return: 0 for success, others for failure.
  */
-int spi_flash_scan(struct spi_slave *spi, struct spi_flash *flash);
+int spi_flash_scan(struct spi_flash *flash);
 
 #endif /* _SF_INTERNAL_H_ */
diff --git a/drivers/mtd/spi/sf_ops.c b/drivers/mtd/spi/sf_ops.c
index 68f191b55e6d684f3df25fb804c4edbb1b934290..c065858be00244a87c98576f954aed07a7886f11 100644
--- a/drivers/mtd/spi/sf_ops.c
+++ b/drivers/mtd/spi/sf_ops.c
@@ -896,8 +896,9 @@ int spi_flash_decode_fdt(const void *blob, struct spi_flash *flash)
 }
 #endif /* CONFIG_IS_ENABLED(OF_CONTROL) */
 
-int spi_flash_scan(struct spi_slave *spi, struct spi_flash *flash)
+int spi_flash_scan(struct spi_flash *flash)
 {
+	struct spi_slave *spi = flash->spi;
 	const struct spi_flash_params *params;
 	u16 jedec, ext_jedec;
 	u8 idcode[5];
@@ -946,7 +947,6 @@ int spi_flash_scan(struct spi_slave *spi, struct spi_flash *flash)
 		write_sr(flash, 0);
 
 	/* Assign spi data */
-	flash->spi = spi;
 	flash->name = params->name;
 	flash->memory_map = spi->memory_map;
 	flash->dual_flash = flash->spi->option;
diff --git a/drivers/mtd/spi/sf_probe.c b/drivers/mtd/spi/sf_probe.c
index f8aad569a9394a183f2a666c97d269c4f8650b23..bf53eef5c212413878eddaa3c6411e1339ce7407 100644
--- a/drivers/mtd/spi/sf_probe.c
+++ b/drivers/mtd/spi/sf_probe.c
@@ -20,12 +20,12 @@
 /**
  * spi_flash_probe_slave() - Probe for a SPI flash device on a bus
  *
- * @spi: Bus to probe
  * @flashp: Pointer to place to put flash info, which may be NULL if the
  * space should be allocated
  */
-int spi_flash_probe_slave(struct spi_slave *spi, struct spi_flash *flash)
+int spi_flash_probe_slave(struct spi_flash *flash)
 {
+	struct spi_slave *spi = flash->spi;
 	int ret;
 
 	/* Setup spi_slave */
@@ -41,7 +41,7 @@ int spi_flash_probe_slave(struct spi_slave *spi, struct spi_flash *flash)
 		return ret;
 	}
 
-	ret = spi_flash_scan(spi, flash);
+	ret = spi_flash_scan(flash);
 	if (ret) {
 		ret = -EINVAL;
 		goto err_read_id;
@@ -68,7 +68,8 @@ struct spi_flash *spi_flash_probe_tail(struct spi_slave *bus)
 		return NULL;
 	}
 
-	if (spi_flash_probe_slave(bus, flash)) {
+	flash->spi = bus;
+	if (spi_flash_probe_slave(flash)) {
 		spi_free_slave(bus);
 		free(flash);
 		return NULL;
@@ -152,8 +153,9 @@ int spi_flash_std_probe(struct udevice *dev)
 
 	flash = dev_get_uclass_priv(dev);
 	flash->dev = dev;
+	flash->spi = slave;
 	debug("%s: slave=%p, cs=%d\n", __func__, slave, plat->cs);
-	return spi_flash_probe_slave(slave, flash);
+	return spi_flash_probe_slave(flash);
 }
 
 static const struct dm_spi_flash_ops spi_flash_std_ops = {