Bagaimana untuk Menggunakan Kaedah 'torch.argmax()' dalam PyTorch?

Bagaimana Untuk Menggunakan Kaedah Torch Argmax Dalam Pytorch



Dalam PyTorch, ' torch.argmax() ” kaedah ialah fungsi terbina dalam yang mengembalikan indeks nilai maksimum tensor tertentu merentas dimensi tertentu. Pengguna menggunakan fungsi ini apabila mereka bekerja dengan tensor dan ingin mencari indeks nilai maksimum di sepanjang dimensi yang diberikan tensor. Selain itu, kaedah ini juga boleh berguna untuk pengelasan di mana pengguna ingin mengetahui kelas mana yang mempunyai kebarangkalian tertinggi.

Blog ini akan memberi contoh kaedah untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch.

Bagaimana untuk Menggunakan Kaedah 'torch.argmax()' dalam PyTorch?

Kaedah 'torch.argmax()' mengambil mana-mana tensor 1D atau 2D sebagai input dan mengembalikan tensor yang mengandungi indeks/indeks nilai maksimum sepanjang dimensi yang diberikan.







Sintaks kaedah 'torch.argmax()' diberikan di bawah:



obor. argmax ( < input_tensor > )

Untuk menggunakan kaedah ini dalam PyTorch, ikuti contoh berikut untuk pemahaman yang lebih baik:



Contoh 1: Gunakan Kaedah 'torch.argmax()' Dengan Tensor 1D

Dalam contoh pertama, kami akan mencipta tensor 1D dan menggunakan kaedah 'torch.argmax()' dengannya. Mari ikuti prosedur langkah demi langkah di bawah:





Langkah 1: Import Pustaka PyTorch

Pertama, import ' obor ” perpustakaan untuk menggunakan kaedah “torch.argmax()”:

import obor

Langkah 2: Buat Tensor 1D

Kemudian, buat tensor 1D dan cetak elemennya. Di sini, kami mencipta yang berikut ' Sepuluh1 ” tensor daripada senarai menggunakan “ torch.tensor() fungsi ':



Sepuluh1 = obor. tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

cetak ( Sepuluh1 )

Ini telah mencipta tensor 1D seperti yang dilihat di bawah:

Langkah 3: Cari Indeks Nilai Maksimum

Sekarang, gunakan ' torch.argmax() fungsi ” untuk mencari indeks/indeks bagi nilai maksimum dalam “ Sepuluh1 ” tensor:

T1_ind = obor. argmax ( Sepuluh1 )

Langkah 4: Cetak Indeks Nilai Maksimum

Akhir sekali, paparkan indeks nilai maksimum dalam tensor input:

cetak ( 'Indeks:' , T1_ind )

Output di bawah menunjukkan indeks nilai maksimum dalam ' Sepuluh1 ” tensor iaitu, 4. Ini bermakna nilai tensor yang paling tinggi adalah pada indeks ke-4 iaitu “ 9 ”:

Contoh 2: Gunakan Kaedah 'torch.argmax()' Dengan Tensor 2D

Dalam contoh kedua, kami akan mencipta tensor 2D dan menggunakan kaedah 'torch.argmax()' dengannya. Jom ikuti langkah yang disediakan:

Langkah 1: Import Pustaka PyTorch

Pertama, import ' obor ” perpustakaan untuk menggunakan kaedah “torch.argmax()”:

import obor

Langkah 2: Buat Tensor 2D

Kemudian, gunakan ' torch.tensor() ” berfungsi untuk mencipta tensor 2D dan mencetak elemennya. Di sini, kami mencipta yang berikut ' Berpuluh2 “Tensor 2D:

Berpuluh2 = obor. tensor ( [ [ 4 , 1 , - 7 ] , [ lima belas , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

cetak ( Berpuluh2 )

Ini telah mencipta tensor 2D seperti yang dilihat di bawah:

Langkah 3: Cari Indeks Nilai Maksimum

Sekarang, cari indeks nilai maksimum dalam “ Berpuluh2 ” tensor dengan menggunakan “ torch.argmax() fungsi ':

T2_ind = obor. argmax ( Berpuluh2 )

Langkah 4: Cetak Indeks Nilai Maksimum

Akhir sekali, paparkan indeks nilai maksimum dalam tensor input:

cetak ( 'Indeks:' , T2_ind )

Menurut output di bawah, indeks nilai maksimum dalam ' Berpuluh2 ” tensor ialah “3”. Ini bermakna nilai tensor tertinggi adalah pada indeks ke-3 iaitu “ lima belas ”:

Langkah 5: Cari Indeks Nilai Maksimum Sepanjang Lajur

Selain itu, pengguna juga boleh mencari indeks/indeks nilai maksimum di sepanjang setiap lajur tensor. Sebagai contoh, kita boleh menggunakan ' malap=0 ” hujah dengan fungsi “torch.argmax()”. Ia mencari indeks nilai maksimum di sepanjang lajur dalam ' Berpuluh2 ” tensor dan kemudian mencetak indeks tersebut:

col_index = obor. argmax ( Berpuluh2 , malap = 0 )

cetak ( 'Indeks dalam lajur:' , col_index )

Output di bawah menunjukkan indeks nilai maksimum di sepanjang setiap lajur tensor:

Langkah 6: Cari Indeks Nilai Maksimum Sepanjang Baris

Begitu juga, pengguna juga boleh mencari indeks/indeks nilai maksimum di sepanjang setiap baris tensor. Sebagai contoh, gunakan ' malap=1 ” argumen dengan fungsi “torch.argmax()” untuk mencari indeks nilai maksimum di sepanjang baris dalam tensor “Tens2” dan kemudian mencetak indeks tersebut:

row_index = obor. argmax ( Berpuluh2 , malap = 1 )

cetak ( 'Indeks dalam baris:' , row_index )

Indeks nilai maksimum di sepanjang setiap baris tensor 'Tens2' boleh dilihat di bawah:

Kami telah menerangkan dengan cekap kaedah untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch.

Catatan : Anda boleh mengakses Buku Nota Google Colab kami di sini pautan .

Kesimpulan

Untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch, pertama, import ' obor ” perpustakaan. Kemudian, cipta tensor 1D atau 2D yang diingini dan lihat elemennya. Seterusnya, gunakan ' torch.argmax() ” kaedah untuk mencari/mengira indeks/indeks bagi nilai maksimum dalam tensor. Selain itu, pengguna juga boleh mencari indeks nilai maksimum di sepanjang setiap baris atau lajur dalam tensor menggunakan ' malap ” hujah. Akhir sekali, paparkan indeks nilai maksimum dalam tensor input. Blog ini telah memberi contoh kaedah untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch.