Blog ini akan memberi contoh kaedah untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch.
Bagaimana untuk Menggunakan Kaedah 'torch.argmax()' dalam PyTorch?
Kaedah 'torch.argmax()' mengambil mana-mana tensor 1D atau 2D sebagai input dan mengembalikan tensor yang mengandungi indeks/indeks nilai maksimum sepanjang dimensi yang diberikan.
Sintaks kaedah 'torch.argmax()' diberikan di bawah:
obor. argmax ( < input_tensor > )
Untuk menggunakan kaedah ini dalam PyTorch, ikuti contoh berikut untuk pemahaman yang lebih baik:
Contoh 1: Gunakan Kaedah 'torch.argmax()' Dengan Tensor 1D
Dalam contoh pertama, kami akan mencipta tensor 1D dan menggunakan kaedah 'torch.argmax()' dengannya. Mari ikuti prosedur langkah demi langkah di bawah:
Langkah 1: Import Pustaka PyTorch
Pertama, import ' obor ” perpustakaan untuk menggunakan kaedah “torch.argmax()”:
import oborLangkah 2: Buat Tensor 1D
Kemudian, buat tensor 1D dan cetak elemennya. Di sini, kami mencipta yang berikut ' Sepuluh1 ” tensor daripada senarai menggunakan “ torch.tensor() fungsi ':
Sepuluh1 = obor. tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )
cetak ( Sepuluh1 )
Ini telah mencipta tensor 1D seperti yang dilihat di bawah:
Langkah 3: Cari Indeks Nilai Maksimum
Sekarang, gunakan ' torch.argmax() fungsi ” untuk mencari indeks/indeks bagi nilai maksimum dalam “ Sepuluh1 ” tensor:
T1_ind = obor. argmax ( Sepuluh1 )Langkah 4: Cetak Indeks Nilai Maksimum
Akhir sekali, paparkan indeks nilai maksimum dalam tensor input:
cetak ( 'Indeks:' , T1_ind )Output di bawah menunjukkan indeks nilai maksimum dalam ' Sepuluh1 ” tensor iaitu, 4. Ini bermakna nilai tensor yang paling tinggi adalah pada indeks ke-4 iaitu “ 9 ”:
Contoh 2: Gunakan Kaedah 'torch.argmax()' Dengan Tensor 2D
Dalam contoh kedua, kami akan mencipta tensor 2D dan menggunakan kaedah 'torch.argmax()' dengannya. Jom ikuti langkah yang disediakan:
Langkah 1: Import Pustaka PyTorch
Pertama, import ' obor ” perpustakaan untuk menggunakan kaedah “torch.argmax()”:
import oborLangkah 2: Buat Tensor 2D
Kemudian, gunakan ' torch.tensor() ” berfungsi untuk mencipta tensor 2D dan mencetak elemennya. Di sini, kami mencipta yang berikut ' Berpuluh2 “Tensor 2D:
Berpuluh2 = obor. tensor ( [ [ 4 , 1 , - 7 ] , [ lima belas , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )cetak ( Berpuluh2 )
Ini telah mencipta tensor 2D seperti yang dilihat di bawah:
Langkah 3: Cari Indeks Nilai Maksimum
Sekarang, cari indeks nilai maksimum dalam “ Berpuluh2 ” tensor dengan menggunakan “ torch.argmax() fungsi ':
T2_ind = obor. argmax ( Berpuluh2 )Langkah 4: Cetak Indeks Nilai Maksimum
Akhir sekali, paparkan indeks nilai maksimum dalam tensor input:
cetak ( 'Indeks:' , T2_ind )Menurut output di bawah, indeks nilai maksimum dalam ' Berpuluh2 ” tensor ialah “3”. Ini bermakna nilai tensor tertinggi adalah pada indeks ke-3 iaitu “ lima belas ”:
Langkah 5: Cari Indeks Nilai Maksimum Sepanjang Lajur
Selain itu, pengguna juga boleh mencari indeks/indeks nilai maksimum di sepanjang setiap lajur tensor. Sebagai contoh, kita boleh menggunakan ' malap=0 ” hujah dengan fungsi “torch.argmax()”. Ia mencari indeks nilai maksimum di sepanjang lajur dalam ' Berpuluh2 ” tensor dan kemudian mencetak indeks tersebut:
col_index = obor. argmax ( Berpuluh2 , malap = 0 )cetak ( 'Indeks dalam lajur:' , col_index )
Output di bawah menunjukkan indeks nilai maksimum di sepanjang setiap lajur tensor:
Langkah 6: Cari Indeks Nilai Maksimum Sepanjang Baris
Begitu juga, pengguna juga boleh mencari indeks/indeks nilai maksimum di sepanjang setiap baris tensor. Sebagai contoh, gunakan ' malap=1 ” argumen dengan fungsi “torch.argmax()” untuk mencari indeks nilai maksimum di sepanjang baris dalam tensor “Tens2” dan kemudian mencetak indeks tersebut:
row_index = obor. argmax ( Berpuluh2 , malap = 1 )cetak ( 'Indeks dalam baris:' , row_index )
Indeks nilai maksimum di sepanjang setiap baris tensor 'Tens2' boleh dilihat di bawah:
Kami telah menerangkan dengan cekap kaedah untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch.
Catatan : Anda boleh mengakses Buku Nota Google Colab kami di sini pautan .
Kesimpulan
Untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch, pertama, import ' obor ” perpustakaan. Kemudian, cipta tensor 1D atau 2D yang diingini dan lihat elemennya. Seterusnya, gunakan ' torch.argmax() ” kaedah untuk mencari/mengira indeks/indeks bagi nilai maksimum dalam tensor. Selain itu, pengguna juga boleh mencari indeks nilai maksimum di sepanjang setiap baris atau lajur dalam tensor menggunakan ' malap ” hujah. Akhir sekali, paparkan indeks nilai maksimum dalam tensor input. Blog ini telah memberi contoh kaedah untuk menggunakan kaedah 'torch.argmax()' dalam PyTorch.